From edbc4ed14e336ef372dbfa36eeb67e6fd73a3351 Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Fri, 1 Mar 2024 15:19:47 +0100 Subject: [PATCH] Make black import lazy. And add some typing. This should take care of #274 --- lib/python/pyflyby/_importstmt.py | 47 ++++++++++++++++++------------- lib/python/pyflyby/_modules.py | 2 +- tests/test_importstmt.py | 20 ++++++------- 3 files changed, 38 insertions(+), 31 deletions(-) diff --git a/lib/python/pyflyby/_importstmt.py b/lib/python/pyflyby/_importstmt.py index dfbf15d8..dbf62a82 100644 --- a/lib/python/pyflyby/_importstmt.py +++ b/lib/python/pyflyby/_importstmt.py @@ -15,17 +15,16 @@ from pyflyby._util import (Inf, cached_attribute, cmp, longest_common_prefix) -from black import format_str, FileMode as Mode -from black.files import find_pyproject_toml, parse_pyproject_toml -from black.mode import TargetVersion +from typing import Dict, Tuple, Optional -def read_black_config(): - """Read the black configuration from ``pyproject.toml`` - """ - pyproject_path = find_pyproject_toml('.') +def read_black_config() -> Dict: + """Read the black configuration from ``pyproject.toml``""" + from black.files import find_pyproject_toml, parse_pyproject_toml + + pyproject_path = find_pyproject_toml((".",)) raw_config = parse_pyproject_toml(pyproject_path) if pyproject_path else {} @@ -348,24 +347,31 @@ def __lt__(self, other): return NotImplemented return self._data < other._data -def _validate_alias(arg): +def _validate_alias(arg) -> Tuple[str, Optional[str]]: """ Ensure each alias is a tuple (str, None|str), and return it. """ assert isinstance(arg, tuple) - a0,a1 = arg + # Pyright does not seem to be able to infer the length from a + # the unpacking. + assert len(arg) == 2 + a0, a1 = arg assert isinstance(a0, str) assert isinstance(a1, (str, type(None))) return arg @total_ordering -class ImportStatement(object): +class ImportStatement: """ Token-level representation of an import statement containing multiple imports from a single module. Corresponds to an ``ast.ImportFrom`` or ``ast.Import``. """ + + aliases : Tuple[Tuple[str, Optional[str]],...] + fromname : Optional[str] + def __new__(cls, arg): if isinstance(arg, cls): return arg @@ -381,10 +387,10 @@ def __new__(cls, arg): raise TypeError @classmethod - def from_parts(cls, fromname, aliases): - assert isinstance(aliases, list) + def from_parts(cls, fromname:Optional[str], aliases:Tuple[Tuple[str, Optional[str]],...]): + assert isinstance(aliases, tuple) assert len(aliases) > 0 - + self = object.__new__(cls) self.fromname = fromname self.aliases = tuple(_validate_alias(a) for a in aliases) @@ -440,7 +446,7 @@ def _from_ast_node(cls, node): raise NonImportStatementError( 'Expected ImportStatement, got {node}'.format(node=node) ) - aliases = [ (alias.name, alias.asname) for alias in node.names ] + aliases = tuple( (alias.name, alias.asname) for alias in node.names ) return cls.from_parts(fromname, aliases) @classmethod @@ -463,7 +469,7 @@ def _from_imports(cls, imports): raise ValueError( "Inconsistent module names %r" % (sorted(module_names),)) fromname = list(module_names)[0] - aliases = [ imp.split[1:] for imp in imports ] + aliases = tuple(imp.split[1:] for imp in imports) return cls.from_parts(fromname, aliases) @cached_attribute @@ -529,12 +535,15 @@ def pretty_print(self, params=FormatParams(), return res @staticmethod - def run_black(src_contents: str, params) -> str: + def run_black(src_contents: str, params:FormatParams) -> str: """Run the black formatter for the Python source code given as a string This is adapted from https://github.com/akaihola/darker """ + from black import format_str, FileMode + from black.mode import TargetVersion + black_config = read_black_config() mode = dict() if "line_length" in black_config: @@ -568,14 +577,12 @@ def run_black(src_contents: str, params) -> str: # ``--skip-string-normalization``, but the parameter for # ``black.Mode`` needs to be the opposite boolean of # ``skip-string-normalization``, hence the inverse boolean - mode["string_normalization"] = not black_config[ - "skip_string_normalization" - ] + mode["string_normalization"] = not black_config["skip_string_normalization"] # The custom handling of empty and all-whitespace files below will be unnecessary if # https://github.com/psf/black/pull/2484 lands in Black. contents_for_black = src_contents - return format_str(contents_for_black, mode=Mode(**mode)) + return format_str(contents_for_black, mode=FileMode(**mode)) @property def _data(self): diff --git a/lib/python/pyflyby/_modules.py b/lib/python/pyflyby/_modules.py index c11becaa..c2a43788 100644 --- a/lib/python/pyflyby/_modules.py +++ b/lib/python/pyflyby/_modules.py @@ -418,7 +418,7 @@ def exports(self): members = [n for n in members if not n.startswith("_")] # Filter out artificially added "deep" members. - members = [(n, None) for n in members if "." not in n] + members = tuple([(n, None) for n in members if "." not in n]) if not members: return None return ImportSet( diff --git a/tests/test_importstmt.py b/tests/test_importstmt.py index 11fbe5aa..5f177ff6 100644 --- a/tests/test_importstmt.py +++ b/tests/test_importstmt.py @@ -218,7 +218,7 @@ def test_ImportStatement_eqne_2(): assert not (stmt1a == stmt2 ) -@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: None) +@patch("black.files.find_pyproject_toml", lambda root: None) def test_ImportStatement_pretty_print_black_no_config(): # running should not error out when no pyproject.toml file is found stmt = ImportStatement("from a import b") @@ -226,16 +226,16 @@ def test_ImportStatement_pretty_print_black_no_config(): assert isinstance(result, str) -@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: None) +@patch("black.files.find_pyproject_toml", lambda root: None) def test_read_black_config_no_config(): # reading black config should work when no pyproject.toml file is found config = read_black_config() assert config == {} -@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml") +@patch("black.files.find_pyproject_toml", lambda root: "pyproject.toml") @patch( - "pyflyby._importstmt.parse_pyproject_toml", + "black.files.parse_pyproject_toml", lambda path: { "line_length": 80, "skip_magic_trailing_comma": True, @@ -253,21 +253,21 @@ def test_read_black_config_extracts_config_subset(): assert "skip_source_first_line" not in config -@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml") -@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": ["py310", "py311"]}) +@patch("black.files.find_pyproject_toml", lambda root: "pyproject.toml") +@patch("black.files.parse_pyproject_toml", lambda path: {"target_version": ["py310", "py311"]}) def test_read_black_config_target_version_list(): config = read_black_config() assert config["target_version"] == {"py310", "py311"} -@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml") -@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": "py311"}) +@patch("black.files.find_pyproject_toml", lambda root: "pyproject.toml") +@patch("black.files.parse_pyproject_toml", lambda path: {"target_version": "py311"}) def test_read_black_config_target_version_str(): config = read_black_config() assert config["target_version"] == "py311" -@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml") -@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": object()}) +@patch("black.files.find_pyproject_toml", lambda root: "pyproject.toml") +@patch("black.files.parse_pyproject_toml", lambda path: {"target_version": object()}) def test_read_black_config_target_version_other(): with raises(ValueError, match="Invalid config for black"): read_black_config()