Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
soulitzer committed Feb 1, 2024
1 parent 375316b commit efc820a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 24 deletions.
17 changes: 17 additions & 0 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
TorchChecker,
TorchCodemod,
TorchCodemodConfig,
DISABLED_BY_DEFAULT,
expand_error_codes,
GET_ALL_VISITORS,
GET_ALL_ERROR_CODES,
process_error_code_str,
)
import logging
import libcst.codemod as codemod
Expand Down Expand Up @@ -61,3 +64,17 @@ def test_errorcodes_distinct():
for e in error_code if isinstance(error_code, list) else [error_code]:
assert e not in seen
seen.add(e)


def test_parse_error_code_str():
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
cases = [
("ALL", list(GET_ALL_ERROR_CODES())),
("ALL,TOR102", list(GET_ALL_ERROR_CODES())),
("TOR102", ["TOR102"]),
("TOR102,TOR101", ["TOR102", "TOR101"]),
("TOR1,TOR102", ["TOR102", "TOR101"]),
(None, list(GET_ALL_ERROR_CODES() - exclude_set)),
]
for case, expected in cases:
assert expected == process_error_code_str(case)
25 changes: 1 addition & 24 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,12 @@
TorchCodemodConfig,
__version__ as TorchFixVersion,
DISABLED_BY_DEFAULT,
expand_error_codes,
GET_ALL_ERROR_CODES,
process_error_code_str,
)
from .common import CYAN, ENDC


def process_error_code_str(code_str):
# Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
# We deduplicate them here.

# Default when --select is not provided.
if code_str is None:
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
return list(GET_ALL_ERROR_CODES() - exclude_set)

raw_codes = [s.strip() for s in code_str.split(",")]

# Validate error codes
for c in raw_codes:
if len(expand_error_codes((c,))) == 0:
raise ValueError(f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}")

if "ALL" in raw_codes:
return list(GET_ALL_ERROR_CODES())

return list(expand_error_codes(raw_codes))


# Should get rid of this code eventually.
@contextlib.contextmanager
def StderrSilencer(redirect: bool = True):
Expand Down
25 changes: 25 additions & 0 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,31 @@ def get_visitors_with_error_codes(error_codes):
return out


def process_error_code_str(code_str):
# Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
# We deduplicate them here.

# Default when --select is not provided.
if code_str is None:
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
return list(GET_ALL_ERROR_CODES() - exclude_set)

raw_codes = [s.strip() for s in code_str.split(",")]

# Validate error codes
for c in raw_codes:
if c == "ALL":
continue
if len(expand_error_codes((c,))) == 0:
raise ValueError(f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}")

if "ALL" in raw_codes:
return list(GET_ALL_ERROR_CODES())

return list(expand_error_codes(tuple(raw_codes)))


# Flake8 plugin
class TorchChecker:
name = "TorchFix"
Expand Down

0 comments on commit efc820a

Please sign in to comment.