Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
soulitzer committed Jan 31, 2024
1 parent 3ccdc6a commit 1971e03
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 43 deletions.
2 changes: 1 addition & 1 deletion tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _checker_results(s):
def _codemod_results(source_path):
with open(source_path) as source:
code = source.read()
config = TorchCodemodConfig(select=GET_ALL_ERROR_CODES())
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
new_module = codemod.transform_module(context, code)
return new_module.code
Expand Down
32 changes: 21 additions & 11 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,33 @@
TorchCodemodConfig,
__version__ as TorchFixVersion,
DISABLED_BY_DEFAULT,
expand_error_codes,
GET_ALL_ERROR_CODES,
)
from .common import CYAN, ENDC


def process_error_code_str(code_str):
# Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
# We deduplicate them here.

# Default when --select is not provided.
if code_str is None:
return [code for code in GET_ALL_ERROR_CODES()
if code not in DISABLED_BY_DEFAULT]
codes = [s.strip() for s in code_str.split(",")]
if "ALL" in codes:
return GET_ALL_ERROR_CODES()
for code in codes:
if code not in GET_ALL_ERROR_CODES():
raise ValueError(f"Invalid error code: {code}, available error "
f"codes: {GET_ALL_ERROR_CODES()}")
return codes
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
return list(GET_ALL_ERROR_CODES() - exclude_set)

raw_codes = [s.strip() for s in code_str.split(",")]

# Validate error codes
for c in raw_codes:
if len(expand_error_codes((c,))) == 0:
raise ValueError(f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}")

if "ALL" in raw_codes:
return list(GET_ALL_ERROR_CODES())

return list(expand_error_codes(raw_codes))


def main() -> None:
Expand All @@ -52,7 +62,7 @@ def main() -> None:
parser.add_argument(
"--select",
help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. "
f"Available rules: {', '.join(GET_ALL_ERROR_CODES())}. "
f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
type=str,
default=None,
Expand Down
79 changes: 48 additions & 31 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import functools
from pathlib import Path
from typing import Optional, List
import libcst as cst
Expand Down Expand Up @@ -35,49 +36,65 @@
TorchReentrantCheckpointVisitor,
]

_ALL_ERROR_CODES = None


@functools.cache
def GET_ALL_ERROR_CODES():
global _ALL_ERROR_CODES
if _ALL_ERROR_CODES is None:
codes = []
for cls in ALL_VISITOR_CLS:
if isinstance(cls.ERROR_CODE, list):
codes += cls.ERROR_CODE
else:
codes.append(cls.ERROR_CODE)
_ALL_ERROR_CODES = codes
return _ALL_ERROR_CODES
codes = set()
for cls in ALL_VISITOR_CLS:
if isinstance(cls.ERROR_CODE, list):
codes |= set(cls.ERROR_CODE)
else:
codes.add(cls.ERROR_CODE)
return codes


@functools.cache
def expand_error_codes(codes):
out_codes = set()
for c_a in codes:
for c_b in GET_ALL_ERROR_CODES():
if c_b.startswith(c_a):
out_codes.add(c_b)
return out_codes


def construct_visitor(cls):
if cls is TorchDeprecatedSymbolsVisitor:
return cls(DEPRECATED_CONFIG_PATH)
else:
return cls()


def GET_ALL_VISITORS():
out = []
for v in ALL_VISITOR_CLS:
if v is TorchDeprecatedSymbolsVisitor:
out.append(v(DEPRECATED_CONFIG_PATH))
else:
out.append(v())
out.append(construct_visitor(v))
return out


def get_visitor_with_error_code(error_code):
# Each error code can only correspond to one visitor
for visitor in GET_ALL_VISITORS():
if isinstance(visitor.ERROR_CODE, list):
if error_code in visitor.ERROR_CODE:
return visitor
else:
if error_code == visitor.ERROR_CODE:
return visitor
raise AssertionError(f"Unknown error code: {error_code}")


def get_visitors_with_error_codes(error_codes):
visitors = []
visitor_classes = set()
for error_code in error_codes:
visitors.append(get_visitor_with_error_code(error_code))
return visitors
# Assume the error codes have been expanded so each error code can
# only correspond to one visitor.
found = False
for visitor_cls in ALL_VISITOR_CLS:
if isinstance(visitor_cls.ERROR_CODE, list):
if error_code in visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
else:
if error_code == visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
if not found:
raise AssertionError(f"Unknown error code: {error_code}")
out = []
for cls in visitor_classes:
out.append(construct_visitor(cls))
return out


# Flake8 plugin
Expand Down

0 comments on commit 1971e03

Please sign in to comment.