Skip to content

Commit

Permalink
Update --select arg to accept specific rules (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
soulitzer authored Feb 1, 2024
1 parent 92136fe commit f219838
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 26 deletions.
20 changes: 19 additions & 1 deletion tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
TorchChecker,
TorchCodemod,
TorchCodemodConfig,
DISABLED_BY_DEFAULT,
expand_error_codes,
GET_ALL_VISITORS,
GET_ALL_ERROR_CODES,
process_error_code_str,
)
import logging
import libcst.codemod as codemod
Expand All @@ -20,7 +24,7 @@ def _checker_results(s):
def _codemod_results(source_path):
with open(source_path) as source:
code = source.read()
config = TorchCodemodConfig(select="ALL")
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
new_module = codemod.transform_module(context, code)
return new_module.code
Expand Down Expand Up @@ -60,3 +64,17 @@ def test_errorcodes_distinct():
for e in error_code if isinstance(error_code, list) else [error_code]:
assert e not in seen
seen.add(e)


def test_parse_error_code_str():
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
cases = [
("ALL", GET_ALL_ERROR_CODES()),
("ALL,TOR102", GET_ALL_ERROR_CODES()),
("TOR102", {"TOR102"}),
("TOR102,TOR101", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101"}),
(None, GET_ALL_ERROR_CODES() - exclude_set),
]
for case, expected in cases:
assert expected == process_error_code_str(case)
20 changes: 14 additions & 6 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
import sys
import io

from .torchfix import TorchCodemod, TorchCodemodConfig, __version__ as TorchFixVersion
from .torchfix import (
TorchCodemod,
TorchCodemodConfig,
__version__ as TorchFixVersion,
DISABLED_BY_DEFAULT,
GET_ALL_ERROR_CODES,
process_error_code_str,
)
from .common import CYAN, ENDC


Expand Down Expand Up @@ -55,10 +62,11 @@ def main() -> None:
)
parser.add_argument(
"--select",
help="ALL to enable rules disabled by default",
choices=[
"ALL",
],
help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. "
f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
type=str,
default=None,
)
parser.add_argument(
"--version",
Expand Down Expand Up @@ -94,7 +102,7 @@ def main() -> None:
break

config = TorchCodemodConfig()
config.select = args.select
config.select = list(process_error_code_str(args.select))
command_instance = TorchCodemod(codemod.CodemodContext(), config)
DIFF_CONTEXT = 5
try:
Expand Down
123 changes: 105 additions & 18 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
import functools
from pathlib import Path
from typing import Optional
from typing import Optional, List
import libcst as cst
import libcst.codemod as codemod

Expand All @@ -25,17 +26,100 @@

DISABLED_BY_DEFAULT = ["TOR3", "TOR4"]

ALL_VISITOR_CLS = [
TorchDeprecatedSymbolsVisitor,
TorchRequireGradVisitor,
TorchSynchronizedDataLoaderVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
]


@functools.cache
def GET_ALL_ERROR_CODES():
codes = set()
for cls in ALL_VISITOR_CLS:
if isinstance(cls.ERROR_CODE, list):
codes |= set(cls.ERROR_CODE)
else:
codes.add(cls.ERROR_CODE)
return codes


@functools.cache
def expand_error_codes(codes):
out_codes = set()
for c_a in codes:
for c_b in GET_ALL_ERROR_CODES():
if c_b.startswith(c_a):
out_codes.add(c_b)
return out_codes


def construct_visitor(cls):
if cls is TorchDeprecatedSymbolsVisitor:
return cls(DEPRECATED_CONFIG_PATH)
else:
return cls()


def GET_ALL_VISITORS():
return [
TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH),
TorchRequireGradVisitor(),
TorchSynchronizedDataLoaderVisitor(),
TorchVisionDeprecatedPretrainedVisitor(),
TorchVisionDeprecatedToTensorVisitor(),
TorchUnsafeLoadVisitor(),
TorchReentrantCheckpointVisitor(),
]
out = []
for v in ALL_VISITOR_CLS:
out.append(construct_visitor(v))
return out


def get_visitors_with_error_codes(error_codes):
visitor_classes = set()
for error_code in error_codes:
# Assume the error codes have been expanded so each error code can
# only correspond to one visitor.
found = False
for visitor_cls in ALL_VISITOR_CLS:
if isinstance(visitor_cls.ERROR_CODE, list):
if error_code in visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
else:
if error_code == visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
if not found:
raise AssertionError(f"Unknown error code: {error_code}")
out = []
for cls in visitor_classes:
out.append(construct_visitor(cls))
return out


def process_error_code_str(code_str):
# Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
# We deduplicate them here.

# Default when --select is not provided.
if code_str is None:
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
return GET_ALL_ERROR_CODES() - exclude_set

raw_codes = [s.strip() for s in code_str.split(",")]

# Validate error codes
for c in raw_codes:
if c == "ALL":
continue
if len(expand_error_codes((c,))) == 0:
raise ValueError(f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}")

if "ALL" in raw_codes:
return GET_ALL_ERROR_CODES()

return expand_error_codes(tuple(raw_codes))


# Flake8 plugin
Expand Down Expand Up @@ -78,7 +162,7 @@ def add_options(optmanager):
# Standalone torchfix command
@dataclass
class TorchCodemodConfig:
select: Optional[str] = None
select: Optional[List[str]] = None


class TorchCodemod(codemod.Codemod):
Expand All @@ -97,8 +181,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
# in that case we would need to use `wrapped_module.module`
# instead of `module`.
wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True)
if self.config is None or self.config.select is None:
raise AssertionError("Expected self.config.select to be set")
visitors = get_visitors_with_error_codes(self.config.select)

visitors = GET_ALL_VISITORS()
violations = []
needed_imports = []
wrapped_module.visit_batched(visitors)
Expand All @@ -110,12 +196,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
replacement_map = {}
assert self.context.filename is not None
for violation in violations:
skip_violation = False
if self.config is None or self.config.select != "ALL":
for disabled_code in DISABLED_BY_DEFAULT:
if violation.error_code.startswith(disabled_code):
skip_violation = True
break
# Still need to skip violations here, since a single visitor can
# correspond to multiple different types of violations.
skip_violation = True
for code in self.config.select:
if violation.error_code.startswith(code):
skip_violation = False
break
if skip_violation:
continue

Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/vision/to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def visit_ImportFrom(self, node):

def visit_Attribute(self, node):
qualified_names = self.get_metadata(cst.metadata.QualifiedNameProvider, node)
if not len(qualified_names) == 1:
if len(qualified_names) != 1:
return

self._maybe_add_violation(qualified_names.pop().name, node)

0 comments on commit f219838

Please sign in to comment.