Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: extend deprecation check to references #81

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 56 additions & 34 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import yaml
from typing import Optional, List

import libcst as cst
from libcst.metadata import QualifiedNameProvider

from ...common import (
TorchVisitor,
TorchError,
TorchVisitor,
call_with_name_changes,
check_old_names_in_import_from,
)

from .range import call_replacement_range
from .cholesky import call_replacement_cholesky
from .chain_matmul import call_replacement_chain_matmul
from .cholesky import call_replacement_cholesky
from .qr import call_replacement_qr
from .range import call_replacement_range


class TorchDeprecatedSymbolsVisitor(TorchVisitor):
Expand All @@ -22,6 +24,8 @@ class TorchDeprecatedSymbolsVisitor(TorchVisitor):
TorchError("TOR101", "Use of deprecated function {old_name}"),
TorchError("TOR004", "Import of removed function {old_name}"),
TorchError("TOR103", "Import of deprecated function {old_name}"),
TorchError("TOR005", "Reference to removed function {old_name}"),
TorchError("TOR106", "Reference to deprecated function {old_name}"),
]

def __init__(self, deprecated_config_path=None):
Expand All @@ -36,10 +40,11 @@ def read_deprecated_config(path=None):

super().__init__()
self.deprecated_config = read_deprecated_config(deprecated_config_path)
self.old_new_name_map = {
self.replacements = {
name: self.deprecated_config[name].get("replacement")
for name in self.deprecated_config
}
self.in_call = False

def _call_replacement(
self, node: cst.Call, qualified_name: str
Expand All @@ -50,11 +55,12 @@ def _call_replacement(
"torch.chain_matmul": call_replacement_chain_matmul,
"torch.qr": call_replacement_qr,
}
replacement = None

if qualified_name in replacements_map:
return replacements_map[qualified_name](node)

replacement = None

# Replace names for functions that have drop-in replacement.
function_name_replacement = self.deprecated_config.get(qualified_name, {}).get(
"replacement", ""
Expand All @@ -68,24 +74,27 @@ def _call_replacement(
self.needed_imports.update(imports)
return replacement

def _construct_error(self, qualified_name, deprecated_key, removed_key):
if "remove_pr" not in self.deprecated_config[qualified_name]:
error_code = self.ERRORS[deprecated_key].error_code
message = self.ERRORS[deprecated_key].message(old_name=qualified_name)
else:
error_code = self.ERRORS[removed_key].error_code
message = self.ERRORS[removed_key].message(old_name=qualified_name)

reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
message = f"{message}: {reference}"

return error_code, message

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
if node.module is None:
return

old_names, replacement = check_old_names_in_import_from(
node, self.old_new_name_map
)
old_names, replacement = check_old_names_in_import_from(node, self.replacements)
for qualified_name in old_names:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERRORS[3].error_code
message = self.ERRORS[3].message(old_name=qualified_name)
else:
error_code = self.ERRORS[2].error_code
message = self.ERRORS[2].message(old_name=qualified_name)

reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
message = f"{message}: {reference}"
error_code, message = self._construct_error(qualified_name, 3, 2)

self.add_violation(
node,
Expand All @@ -94,24 +103,37 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
replacement=replacement,
)

def visit_Call(self, node) -> None:
def visit_Call(self, node: cst.Call) -> None:
self.in_call = True
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name is None:
return

if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERRORS[1].error_code
message = self.ERRORS[1].message(old_name=qualified_name)
else:
error_code = self.ERRORS[0].error_code
message = self.ERRORS[0].message(old_name=qualified_name)
replacement = self._call_replacement(node, qualified_name)
if qualified_name not in self.deprecated_config:
return

error_code, message = self._construct_error(qualified_name, 1, 0)
replacement = self._call_replacement(node, qualified_name)
self.add_violation(
node, error_code=error_code, message=message, replacement=replacement
)

reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
message = f"{message}: {reference}"
def leave_Call(self, original_node: cst.Call) -> None:
self.in_call = False

self.add_violation(
node, error_code=error_code, message=message, replacement=replacement
)
def visit_Attribute(self, node: cst.Attribute):
# avoid duplicates
if self.in_call:
return False

name_metadata = list(self.get_metadata(QualifiedNameProvider, node))
if not name_metadata:
return False

qualified_name = name_metadata[0].name
if qualified_name not in self.deprecated_config:
return False

error_code, message = self._construct_error(qualified_name, 4, 5)
self.add_violation(node, error_code=error_code, message=message)
return None
Loading