Skip to content

Commit

Permalink
Update arg to accept specific rules
Browse files Browse the repository at this point in the history
  • Loading branch information
soulitzer committed Jan 29, 2024
1 parent 9e2c8e6 commit 6ad8d5a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 25 deletions.
23 changes: 17 additions & 6 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,19 @@
import sys
import io

from .torchfix import TorchCodemod, TorchCodemodConfig
from .torchfix import TorchCodemod, TorchCodemodConfig, DISABLED_BY_DEFAULT, GET_ALL_ERROR_CODES
from .common import CYAN, ENDC

def process_error_code_str(code_str):
if code_str is None:
return [code for code in GET_ALL_ERROR_CODES() if code not in DISABLED_BY_DEFAULT]
codes = [s.strip() for s in code_str.split(",")]
if "ALL" in codes:
return GET_ALL_ERROR_CODES()
for code in codes:
if code not in GET_ALL_ERROR_CODES():
raise ValueError(f"Invalid error code: {code}, available error codes: {GET_ALL_ERROR_CODES()}")
return codes

def main() -> None:
parser = argparse.ArgumentParser()
Expand All @@ -31,10 +41,11 @@ def main() -> None:
)
parser.add_argument(
"--select",
help="ALL to enable rules disabled by default",
choices=[
"ALL",
],
help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. "
f"Available rules: {', '.join(GET_ALL_ERROR_CODES())}. "
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
type=str,
default=None,
)

# XXX TODO: Get rid of this!
Expand All @@ -61,7 +72,7 @@ def main() -> None:
break

config = TorchCodemodConfig()
config.select = args.select
config.select = process_error_code_str(args.select)
command_instance = TorchCodemod(codemod.CodemodContext(), config)
DIFF_CONTEXT = 5
try:
Expand Down
76 changes: 57 additions & 19 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from typing import Optional, List
import libcst as cst
import libcst.codemod as codemod

Expand All @@ -25,18 +25,55 @@

DISABLED_BY_DEFAULT = ["TOR3", "TOR4"]

ALL_VISITOR_CLS = [
TorchDeprecatedSymbolsVisitor,
TorchRequireGradVisitor,
TorchSynchronizedDataLoaderVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
]

_ALL_ERROR_CODES = None

def GET_ALL_ERROR_CODES():
global _ALL_ERROR_CODES
if _ALL_ERROR_CODES is None:
codes = []
for cls in ALL_VISITOR_CLS:
if isinstance(cls.ERROR_CODE, list):
codes += cls.ERROR_CODE
else:
codes.append(cls.ERROR_CODE)
_ALL_ERROR_CODES = codes
return _ALL_ERROR_CODES

def GET_ALL_VISITORS():
return [
TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH),
TorchRequireGradVisitor(),
TorchSynchronizedDataLoaderVisitor(),
TorchVisionDeprecatedPretrainedVisitor(),
TorchVisionDeprecatedToTensorVisitor(),
TorchUnsafeLoadVisitor(),
TorchReentrantCheckpointVisitor(),
]

out = []
for v in ALL_VISITOR_CLS:
if v is TorchDeprecatedSymbolsVisitor:
out.append(v(DEPRECATED_CONFIG_PATH))
else:
out.append(v())
return out

def get_visitor_with_error_code(error_code):
# Each error code can only correspond to one visitor
for visitor in GET_ALL_VISITORS():
if isinstance(visitor.ERROR_CODE, list):
if error_code in visitor.ERROR_CODE:
return visitor
else:
if error_code == visitor.ERROR_CODE:
return visitor
assert False, f"Unknown error code: {error_code}"

def get_visitors_with_error_codes(error_codes):
visitors = []
for error_code in error_codes:
visitors.append(get_visitor_with_error_code(error_code))
return visitors

# Flake8 plugin
class TorchChecker:
Expand Down Expand Up @@ -78,7 +115,7 @@ def add_options(optmanager):
# Standalone torchfix command
@dataclass
class TorchCodemodConfig:
select: Optional[str] = None
select: List[str] = None


class TorchCodemod(codemod.Codemod):
Expand All @@ -97,8 +134,8 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
# in that case we would need to use `wrapped_module.module`
# instead of `module`.
wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True)
visitors = get_visitors_with_error_codes(self.config.select)

visitors = GET_ALL_VISITORS()
violations = []
needed_imports = []
wrapped_module.visit_batched(visitors)
Expand All @@ -110,12 +147,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
replacement_map = {}
assert self.context.filename is not None
for violation in violations:
skip_violation = False
if self.config is None or self.config.select != "ALL":
for disabled_code in DISABLED_BY_DEFAULT:
if violation.error_code.startswith(disabled_code):
skip_violation = True
break
# Still need to skip violations here, since a single visitor can
# correspond to multiple different types of violations.
skip_violation = True
for code in self.config.select:
if violation.error_code.startswith(code):
skip_violation = False
break
if skip_violation:
continue

Expand Down

0 comments on commit 6ad8d5a

Please sign in to comment.