Skip to content

Commit

Permalink
Fix distinct error codes test (#73)
Browse files Browse the repository at this point in the history
Closes #71
  • Loading branch information
sbrugman authored Sep 5, 2024
1 parent 555bb8d commit c755731
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 20 deletions.
4 changes: 2 additions & 2 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def test_errorcodes_distinct():
LOGGER.info("Checking error code for %s", visitor.__class__.__name__)
errors = visitor.ERRORS
for e in errors:
assert e not in seen
seen.add(e)
assert e.error_code not in seen
seen.add(e.error_code)


def test_parse_error_code_str():
Expand Down
22 changes: 4 additions & 18 deletions torchfix/visitors/vision/singleton_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,8 @@ class TorchVisionSingletonImportVisitor(TorchVisitor):
TorchError(
"TOR203",
(
"Consider replacing 'import torchvision.datasets as datasets' "
"with 'from torchvision import datasets'."
),
),
TorchError(
"TOR203",
(
"Consider replacing 'import torchvision.models as models' "
"with 'from torchvision import models'."
),
),
TorchError(
"TOR203",
(
"Consider replacing 'import torchvision.transforms as transforms' "
"with 'from torchvision import transforms'."
"Consider replacing 'import torchvision.{module} as {module}' "
"with 'from torchvision import {module}'."
),
),
]
Expand Down Expand Up @@ -53,8 +39,8 @@ def visit_Import(self, node: cst.Import) -> None:
)
self.add_violation(
node,
error_code=self.ERRORS[i].error_code,
message=self.ERRORS[i].message(),
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(module=import_attr),
replacement=replacement,
)
break

0 comments on commit c755731

Please sign in to comment.