diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 21a8994..4a1de3c 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -26,7 +26,7 @@ __version__ = "0.4.0" -DEPRECATED_CONFIG_PATH = Path(__file__).absolute().parent / "deprecated_symbols.yaml" +DEPRECATED_CONFIG_PATH = "deprecated_symbols.yaml" DISABLED_BY_DEFAULT = ["TOR3", "TOR4", "TOR9"] diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 3450777..9949242 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -1,4 +1,5 @@ import libcst as cst +import pkgutil import yaml from typing import Optional from collections.abc import Sequence @@ -21,9 +22,9 @@ def __init__(self, deprecated_config_path=None): def read_deprecated_config(path=None): deprecated_config = {} if path is not None: - with open(path) as f: - for item in yaml.load(f, yaml.SafeLoader): - deprecated_config[item["name"]] = item + data = pkgutil.get_data("torchfix", path) + for item in yaml.load(data, yaml.SafeLoader): + deprecated_config[item["name"]] = item return deprecated_config super().__init__()