Skip to content

Commit

Permalink
Add ReplaceImportsTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
kit1980 committed Mar 9, 2024
1 parent 8846f5c commit 8447683
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 46 deletions.
2 changes: 2 additions & 0 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def _codemod_results(source_path):
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
new_module = codemod.transform_module(context, code)
if isinstance(new_module, codemod.TransformFailure):
raise new_module.error
return new_module.code


Expand Down
50 changes: 47 additions & 3 deletions torchfix/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dataclasses import dataclass
import dataclasses
import sys
import libcst as cst
from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider
from libcst.codemod.visitors import ImportItem
from typing import Optional, List, Set, Tuple, Union
from typing import Optional, List, Set, Tuple, Union, Dict
from abc import ABC

IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
Expand All @@ -13,7 +13,7 @@
ENDC = "\033[0m" if IS_TTY else ""


@dataclass
@dataclasses.dataclass(frozen=True)
class LintViolation:
error_code: str
message: str
Expand All @@ -34,6 +34,49 @@ def codemod_result(self) -> str:
return f"{position} {error_code}{fixable} {self.message}"


@dataclasses.dataclass(frozen=True)
class ToReplaceImportItem:
old_module: str
old_names: Tuple[str, ...]
new_module: str


class ReplaceImportsTransformer(cst.CSTTransformer):
def __init__(self, to_replace_imports: Set[ToReplaceImportItem]) -> None:
super().__init__()
self.changed = False

# Merge all items with the same old_module.
self.to_replace_imports: Dict[str, ToReplaceImportItem] = {}
for item in to_replace_imports:
if item.old_module in self.to_replace_imports:
existing_item = self.to_replace_imports[item.old_module]
# Assert no different new_module for the same old_module.
assert item.new_module == existing_item.new_module
merged_old_names = existing_item.old_names + item.old_names
existing_item = dataclasses.replace(
existing_item, old_names=merged_old_names
)
else:
self.to_replace_imports[item.old_module] = item

def leave_ImportFrom(
self, node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
if node.module is not None:
module = cst.helpers.get_full_name_for_node(node.module)
if module in self.to_replace_imports:
replace_item = self.to_replace_imports[module]
if all(
name.name.value in replace_item.old_names for name in node.names
):
self.changed = True
return updated_node.with_changes(
module=cst.parse_expression(replace_item.new_module)
)
return updated_node


class TorchVisitor(cst.BatchableCSTVisitor, ABC):
METADATA_DEPENDENCIES = (
QualifiedNameProvider,
Expand All @@ -45,6 +88,7 @@ class TorchVisitor(cst.BatchableCSTVisitor, ABC):
def __init__(self) -> None:
self.violations: List[LintViolation] = []
self.needed_imports: Set[ImportItem] = set()
self.to_replace_imports: Set[ToReplaceImportItem] = set()

@staticmethod
def get_specific_arg(
Expand Down
25 changes: 13 additions & 12 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from dataclasses import dataclass
import functools
from pathlib import Path
from typing import Optional, List
from typing import Optional, List, Set
import libcst as cst
import libcst.codemod as codemod

from .common import deep_multi_replace
from .visitors.deprecated_symbols import (
TorchDeprecatedSymbolsVisitor,
_UpdateFunctorchImports,
)
from .common import deep_multi_replace, ReplaceImportsTransformer, ToReplaceImportItem
from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor

from .visitors.internal import TorchScopedLibraryVisitor

from .visitors.performance import TorchSynchronizedDataLoaderVisitor
from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor)
from .visitors.misc import TorchRequireGradVisitor, TorchReentrantCheckpointVisitor

from .visitors.vision import (
TorchVisionDeprecatedPretrainedVisitor,
Expand Down Expand Up @@ -118,8 +115,10 @@ def process_error_code_str(code_str):
if c == "ALL":
continue
if len(expand_error_codes((c,))) == 0:
raise ValueError(f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}")
raise ValueError(
f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}"
)

if "ALL" in raw_codes:
return GET_ALL_ERROR_CODES()
Expand Down Expand Up @@ -192,10 +191,12 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:

violations = []
needed_imports = []
to_replace_imports: Set[ToReplaceImportItem] = set()
wrapped_module.visit_batched(visitors)
for v in visitors:
violations += v.violations
needed_imports += v.needed_imports
to_replace_imports.update(v.to_replace_imports)

fixes_count = 0
replacement_map = {}
Expand Down Expand Up @@ -228,10 +229,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
)
new_module = new_module.visit(add_imports_visitor)

update_functorch_imports_visitor = _UpdateFunctorchImports()
new_module = new_module.visit(update_functorch_imports_visitor)
replace_imports_transformer = ReplaceImportsTransformer(to_replace_imports)
new_module = new_module.visit(replace_imports_transformer)

if fixes_count == 0 and not update_functorch_imports_visitor.changed:
if fixes_count == 0 and not replace_imports_transformer.changed:
raise codemod.SkipFile("No changes")

return new_module
47 changes: 16 additions & 31 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import libcst as cst
import yaml
from typing import Optional
from collections.abc import Sequence

from ...common import (
TorchVisitor,
call_with_name_changes,
ToReplaceImportItem,
LintViolation,
)

Expand All @@ -17,6 +16,20 @@

class TorchDeprecatedSymbolsVisitor(TorchVisitor):
ERROR_CODE = ["TOR001", "TOR101"]
FUCTORCH_REPLACE_IMPORTS = ToReplaceImportItem(
"functorch",
(
"vmap",
"grad",
"vjp",
"jvp",
"jacrev",
"jacfwd",
"hessian",
"functionalize",
),
"torch.func",
)

def __init__(self, deprecated_config_path=None):
def read_deprecated_config(path=None):
Expand All @@ -29,6 +42,7 @@ def read_deprecated_config(path=None):

super().__init__()
self.deprecated_config = read_deprecated_config(deprecated_config_path)
self.to_replace_imports.add(self.FUCTORCH_REPLACE_IMPORTS)

def _call_replacement(
self, node: cst.Call, qualified_name: str
Expand Down Expand Up @@ -88,32 +102,3 @@ def visit_Call(self, node):
replacement=replacement,
)
)


# TODO: refactor/generalize this.
class _UpdateFunctorchImports(cst.CSTTransformer):
REPLACEMENTS = {
"vmap",
"grad",
"vjp",
"jvp",
"jacrev",
"jacfwd",
"hessian",
"functionalize",
}

def __init__(self):
self.changed = False

def leave_ImportFrom(
self, node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
if (
getattr(node.module, "value", None) == "functorch"
and isinstance(node.names, Sequence)
and all(name.name.value in self.REPLACEMENTS for name in node.names)
):
self.changed = True
return updated_node.with_changes(module=cst.parse_expression("torch.func"))
return updated_node

0 comments on commit 8447683

Please sign in to comment.