From f16567ab48d8d424196ffb802ed228c1be4b04ce Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Fri, 8 Mar 2024 16:53:18 -0800 Subject: [PATCH] Add ReplaceImportsTransformer --- tests/test_torchfix.py | 2 + torchfix/common.py | 50 +++++++++++++++++-- torchfix/torchfix.py | 25 +++++----- .../visitors/deprecated_symbols/__init__.py | 47 ++++++----------- 4 files changed, 78 insertions(+), 46 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index cd9b74c..6e7e0c6 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -27,6 +27,8 @@ def _codemod_results(source_path): config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) + if isinstance(new_module, codemod.TransformFailure): + raise new_module.error return new_module.code diff --git a/torchfix/common.py b/torchfix/common.py index b302346..2d8cda5 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -1,9 +1,9 @@ -from dataclasses import dataclass +import dataclasses import sys import libcst as cst from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider from libcst.codemod.visitors import ImportItem -from typing import Optional, List, Set, Tuple, Union +from typing import Optional, List, Set, Tuple, Union, Dict, Sequence from abc import ABC IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() @@ -13,7 +13,7 @@ ENDC = "\033[0m" if IS_TTY else "" -@dataclass +@dataclasses.dataclass(frozen=True) class LintViolation: error_code: str message: str @@ -34,6 +34,49 @@ def codemod_result(self) -> str: return f"{position} {error_code}{fixable} {self.message}" +@dataclasses.dataclass(frozen=True) +class ToReplaceImportItem: + old_module: str + old_names: Tuple[str, ...] + new_module: str + + +class ReplaceImportsTransformer(cst.CSTTransformer): + def __init__(self, to_replace_imports: Set[ToReplaceImportItem]) -> None: + super().__init__() + self.changed = False + + # Merge all items with the same old_module. + self.to_replace_imports: Dict[str, ToReplaceImportItem] = {} + for item in to_replace_imports: + if item.old_module in self.to_replace_imports: + existing_item = self.to_replace_imports[item.old_module] + # Assert no different new_module for the same old_module. + assert item.new_module == existing_item.new_module + merged_old_names = existing_item.old_names + item.old_names + existing_item = dataclasses.replace( + existing_item, old_names=merged_old_names + ) + else: + self.to_replace_imports[item.old_module] = item + + def leave_ImportFrom( + self, node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> cst.ImportFrom: + if node.module is not None: + module = cst.helpers.get_full_name_for_node(node.module) + if module in self.to_replace_imports: + replace_item = self.to_replace_imports[module] + if isinstance(node.names, Sequence) and all( + name.name.value in replace_item.old_names for name in node.names + ): + self.changed = True + return updated_node.with_changes( + module=cst.parse_expression(replace_item.new_module) + ) + return updated_node + + class TorchVisitor(cst.BatchableCSTVisitor, ABC): METADATA_DEPENDENCIES = ( QualifiedNameProvider, @@ -45,6 +88,7 @@ class TorchVisitor(cst.BatchableCSTVisitor, ABC): def __init__(self) -> None: self.violations: List[LintViolation] = [] self.needed_imports: Set[ImportItem] = set() + self.to_replace_imports: Set[ToReplaceImportItem] = set() @staticmethod def get_specific_arg( diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index a38d81d..dd9b26d 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -1,20 +1,17 @@ from dataclasses import dataclass import functools from pathlib import Path -from typing import Optional, List +from typing import Optional, List, Set import libcst as cst import libcst.codemod as codemod -from .common import deep_multi_replace -from .visitors.deprecated_symbols import ( - TorchDeprecatedSymbolsVisitor, - _UpdateFunctorchImports, -) +from .common import deep_multi_replace, ReplaceImportsTransformer, ToReplaceImportItem +from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor from .visitors.internal import TorchScopedLibraryVisitor from .visitors.performance import TorchSynchronizedDataLoaderVisitor -from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) +from .visitors.misc import TorchRequireGradVisitor, TorchReentrantCheckpointVisitor from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, @@ -118,8 +115,10 @@ def process_error_code_str(code_str): if c == "ALL": continue if len(expand_error_codes((c,))) == 0: - raise ValueError(f"Invalid error code: {c}, available error " - f"codes: {list(GET_ALL_ERROR_CODES())}") + raise ValueError( + f"Invalid error code: {c}, available error " + f"codes: {list(GET_ALL_ERROR_CODES())}" + ) if "ALL" in raw_codes: return GET_ALL_ERROR_CODES() @@ -192,10 +191,12 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: violations = [] needed_imports = [] + to_replace_imports: Set[ToReplaceImportItem] = set() wrapped_module.visit_batched(visitors) for v in visitors: violations += v.violations needed_imports += v.needed_imports + to_replace_imports.update(v.to_replace_imports) fixes_count = 0 replacement_map = {} @@ -228,10 +229,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: ) new_module = new_module.visit(add_imports_visitor) - update_functorch_imports_visitor = _UpdateFunctorchImports() - new_module = new_module.visit(update_functorch_imports_visitor) + replace_imports_transformer = ReplaceImportsTransformer(to_replace_imports) + new_module = new_module.visit(replace_imports_transformer) - if fixes_count == 0 and not update_functorch_imports_visitor.changed: + if fixes_count == 0 and not replace_imports_transformer.changed: raise codemod.SkipFile("No changes") return new_module diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 93a9082..1afbe93 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -1,11 +1,10 @@ import libcst as cst import yaml from typing import Optional -from collections.abc import Sequence - from ...common import ( TorchVisitor, call_with_name_changes, + ToReplaceImportItem, LintViolation, ) @@ -17,6 +16,20 @@ class TorchDeprecatedSymbolsVisitor(TorchVisitor): ERROR_CODE = ["TOR001", "TOR101"] + FUCTORCH_REPLACE_IMPORTS = ToReplaceImportItem( + "functorch", + ( + "vmap", + "grad", + "vjp", + "jvp", + "jacrev", + "jacfwd", + "hessian", + "functionalize", + ), + "torch.func", + ) def __init__(self, deprecated_config_path=None): def read_deprecated_config(path=None): @@ -29,6 +42,7 @@ def read_deprecated_config(path=None): super().__init__() self.deprecated_config = read_deprecated_config(deprecated_config_path) + self.to_replace_imports.add(self.FUCTORCH_REPLACE_IMPORTS) def _call_replacement( self, node: cst.Call, qualified_name: str @@ -88,32 +102,3 @@ def visit_Call(self, node): replacement=replacement, ) ) - - -# TODO: refactor/generalize this. -class _UpdateFunctorchImports(cst.CSTTransformer): - REPLACEMENTS = { - "vmap", - "grad", - "vjp", - "jvp", - "jacrev", - "jacfwd", - "hessian", - "functionalize", - } - - def __init__(self): - self.changed = False - - def leave_ImportFrom( - self, node: cst.ImportFrom, updated_node: cst.ImportFrom - ) -> cst.ImportFrom: - if ( - getattr(node.module, "value", None) == "functorch" - and isinstance(node.names, Sequence) - and all(name.name.value in self.REPLACEMENTS for name in node.names) - ): - self.changed = True - return updated_node.with_changes(module=cst.parse_expression("torch.func")) - return updated_node