Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update --select arg to accept specific rules #16

Merged
merged 6 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
TorchChecker,
TorchCodemod,
TorchCodemodConfig,
DISABLED_BY_DEFAULT,
expand_error_codes,
GET_ALL_VISITORS,
GET_ALL_ERROR_CODES,
process_error_code_str,
)
import logging
import libcst.codemod as codemod
Expand All @@ -20,7 +24,7 @@ def _checker_results(s):
def _codemod_results(source_path):
with open(source_path) as source:
code = source.read()
config = TorchCodemodConfig(select="ALL")
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
new_module = codemod.transform_module(context, code)
return new_module.code
Expand Down Expand Up @@ -60,3 +64,17 @@ def test_errorcodes_distinct():
for e in error_code if isinstance(error_code, list) else [error_code]:
assert e not in seen
seen.add(e)


def test_parse_error_code_str():
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
cases = [
("ALL", GET_ALL_ERROR_CODES()),
("ALL,TOR102", GET_ALL_ERROR_CODES()),
("TOR102", {"TOR102"}),
("TOR102,TOR101", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101"}),
(None, GET_ALL_ERROR_CODES() - exclude_set),
]
for case, expected in cases:
assert expected == process_error_code_str(case)
20 changes: 14 additions & 6 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
import sys
import io

from .torchfix import TorchCodemod, TorchCodemodConfig, __version__ as TorchFixVersion
from .torchfix import (
TorchCodemod,
TorchCodemodConfig,
__version__ as TorchFixVersion,
DISABLED_BY_DEFAULT,
GET_ALL_ERROR_CODES,
process_error_code_str,
)
from .common import CYAN, ENDC


Expand Down Expand Up @@ -55,10 +62,11 @@ def main() -> None:
)
parser.add_argument(
"--select",
help="ALL to enable rules disabled by default",
choices=[
"ALL",
],
help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. "
f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
type=str,
default=None,
)
parser.add_argument(
"--version",
Expand Down Expand Up @@ -94,7 +102,7 @@ def main() -> None:
break

config = TorchCodemodConfig()
config.select = args.select
config.select = list(process_error_code_str(args.select))
command_instance = TorchCodemod(codemod.CodemodContext(), config)
DIFF_CONTEXT = 5
try:
Expand Down
123 changes: 105 additions & 18 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
import functools
from pathlib import Path
from typing import Optional
from typing import Optional, List
import libcst as cst
import libcst.codemod as codemod

Expand All @@ -25,17 +26,100 @@

DISABLED_BY_DEFAULT = ["TOR3", "TOR4"]

ALL_VISITOR_CLS = [
TorchDeprecatedSymbolsVisitor,
TorchRequireGradVisitor,
TorchSynchronizedDataLoaderVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
]


@functools.cache
def GET_ALL_ERROR_CODES():
codes = set()
for cls in ALL_VISITOR_CLS:
if isinstance(cls.ERROR_CODE, list):
codes |= set(cls.ERROR_CODE)
else:
codes.add(cls.ERROR_CODE)
return codes


@functools.cache
def expand_error_codes(codes):
out_codes = set()
for c_a in codes:
for c_b in GET_ALL_ERROR_CODES():
if c_b.startswith(c_a):
out_codes.add(c_b)
return out_codes


def construct_visitor(cls):
if cls is TorchDeprecatedSymbolsVisitor:
return cls(DEPRECATED_CONFIG_PATH)
else:
return cls()


def GET_ALL_VISITORS():
return [
TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH),
TorchRequireGradVisitor(),
TorchSynchronizedDataLoaderVisitor(),
TorchVisionDeprecatedPretrainedVisitor(),
TorchVisionDeprecatedToTensorVisitor(),
TorchUnsafeLoadVisitor(),
TorchReentrantCheckpointVisitor(),
]
out = []
for v in ALL_VISITOR_CLS:
out.append(construct_visitor(v))
return out


def get_visitors_with_error_codes(error_codes):
visitor_classes = set()
for error_code in error_codes:
# Assume the error codes have been expanded so each error code can
# only correspond to one visitor.
found = False
for visitor_cls in ALL_VISITOR_CLS:
if isinstance(visitor_cls.ERROR_CODE, list):
if error_code in visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
else:
if error_code == visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
if not found:
raise AssertionError(f"Unknown error code: {error_code}")
out = []
for cls in visitor_classes:
out.append(construct_visitor(cls))
return out


def process_error_code_str(code_str):
# Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
# We deduplicate them here.

# Default when --select is not provided.
if code_str is None:
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
return GET_ALL_ERROR_CODES() - exclude_set

raw_codes = [s.strip() for s in code_str.split(",")]

# Validate error codes
for c in raw_codes:
if c == "ALL":
continue
if len(expand_error_codes((c,))) == 0:
raise ValueError(f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}")

if "ALL" in raw_codes:
return GET_ALL_ERROR_CODES()

return expand_error_codes(tuple(raw_codes))


# Flake8 plugin
Expand Down Expand Up @@ -78,7 +162,7 @@ def add_options(optmanager):
# Standalone torchfix command
@dataclass
class TorchCodemodConfig:
select: Optional[str] = None
select: Optional[List[str]] = None


class TorchCodemod(codemod.Codemod):
Expand All @@ -97,8 +181,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
# in that case we would need to use `wrapped_module.module`
# instead of `module`.
wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True)
if self.config is None or self.config.select is None:
raise AssertionError("Expected self.config.select to be set")
visitors = get_visitors_with_error_codes(self.config.select)

visitors = GET_ALL_VISITORS()
violations = []
needed_imports = []
wrapped_module.visit_batched(visitors)
Expand All @@ -110,12 +196,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
replacement_map = {}
assert self.context.filename is not None
for violation in violations:
skip_violation = False
if self.config is None or self.config.select != "ALL":
for disabled_code in DISABLED_BY_DEFAULT:
if violation.error_code.startswith(disabled_code):
skip_violation = True
break
# Still need to skip violations here, since a single visitor can
# correspond to multiple different types of violations.
skip_violation = True
for code in self.config.select:
if violation.error_code.startswith(code):
skip_violation = False
break
if skip_violation:
continue

Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/vision/to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def visit_ImportFrom(self, node):

def visit_Attribute(self, node):
qualified_names = self.get_metadata(cst.metadata.QualifiedNameProvider, node)
if not len(qualified_names) == 1:
if len(qualified_names) != 1:
return

self._maybe_add_violation(qualified_names.pop().name, node)
Loading