Skip to content

Commit

Permalink
Merge pull request #312 from Carreau/lazy-black
Browse files Browse the repository at this point in the history
Make black import lazy.
  • Loading branch information
Carreau authored Mar 6, 2024
2 parents 2b4c0eb + edbc4ed commit a058ba9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 31 deletions.
47 changes: 27 additions & 20 deletions lib/python/pyflyby/_importstmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
from pyflyby._util import (Inf, cached_attribute, cmp,
longest_common_prefix)

from black import format_str, FileMode as Mode
from black.files import find_pyproject_toml, parse_pyproject_toml
from black.mode import TargetVersion

from typing import Dict, Tuple, Optional

def read_black_config():
"""Read the black configuration from ``pyproject.toml``


"""
pyproject_path = find_pyproject_toml('.')
def read_black_config() -> Dict:
"""Read the black configuration from ``pyproject.toml``"""
from black.files import find_pyproject_toml, parse_pyproject_toml

pyproject_path = find_pyproject_toml((".",))

raw_config = parse_pyproject_toml(pyproject_path) if pyproject_path else {}

Expand Down Expand Up @@ -348,24 +347,31 @@ def __lt__(self, other):
return NotImplemented
return self._data < other._data

def _validate_alias(arg):
def _validate_alias(arg) -> Tuple[str, Optional[str]]:
"""
Ensure each alias is a tuple (str, None|str), and return it.
"""
assert isinstance(arg, tuple)
a0,a1 = arg
# Pyright does not seem to be able to infer the length from a
# the unpacking.
assert len(arg) == 2
a0, a1 = arg
assert isinstance(a0, str)
assert isinstance(a1, (str, type(None)))
return arg

@total_ordering
class ImportStatement(object):
class ImportStatement:
"""
Token-level representation of an import statement containing multiple
imports from a single module. Corresponds to an ``ast.ImportFrom`` or
``ast.Import``.
"""

aliases : Tuple[Tuple[str, Optional[str]],...]
fromname : Optional[str]

def __new__(cls, arg):
if isinstance(arg, cls):
return arg
Expand All @@ -381,10 +387,10 @@ def __new__(cls, arg):
raise TypeError

@classmethod
def from_parts(cls, fromname, aliases):
assert isinstance(aliases, list)
def from_parts(cls, fromname:Optional[str], aliases:Tuple[Tuple[str, Optional[str]],...]):
assert isinstance(aliases, tuple)
assert len(aliases) > 0

self = object.__new__(cls)
self.fromname = fromname
self.aliases = tuple(_validate_alias(a) for a in aliases)
Expand Down Expand Up @@ -440,7 +446,7 @@ def _from_ast_node(cls, node):
raise NonImportStatementError(
'Expected ImportStatement, got {node}'.format(node=node)
)
aliases = [ (alias.name, alias.asname) for alias in node.names ]
aliases = tuple( (alias.name, alias.asname) for alias in node.names )
return cls.from_parts(fromname, aliases)

@classmethod
Expand All @@ -463,7 +469,7 @@ def _from_imports(cls, imports):
raise ValueError(
"Inconsistent module names %r" % (sorted(module_names),))
fromname = list(module_names)[0]
aliases = [ imp.split[1:] for imp in imports ]
aliases = tuple(imp.split[1:] for imp in imports)
return cls.from_parts(fromname, aliases)

@cached_attribute
Expand Down Expand Up @@ -529,12 +535,15 @@ def pretty_print(self, params=FormatParams(),
return res

@staticmethod
def run_black(src_contents: str, params) -> str:
def run_black(src_contents: str, params:FormatParams) -> str:
"""Run the black formatter for the Python source code given as a string
This is adapted from https://github.com/akaihola/darker
"""
from black import format_str, FileMode
from black.mode import TargetVersion

black_config = read_black_config()
mode = dict()
if "line_length" in black_config:
Expand Down Expand Up @@ -568,14 +577,12 @@ def run_black(src_contents: str, params) -> str:
# ``--skip-string-normalization``, but the parameter for
# ``black.Mode`` needs to be the opposite boolean of
# ``skip-string-normalization``, hence the inverse boolean
mode["string_normalization"] = not black_config[
"skip_string_normalization"
]
mode["string_normalization"] = not black_config["skip_string_normalization"]

# The custom handling of empty and all-whitespace files below will be unnecessary if
# https://github.com/psf/black/pull/2484 lands in Black.
contents_for_black = src_contents
return format_str(contents_for_black, mode=Mode(**mode))
return format_str(contents_for_black, mode=FileMode(**mode))

@property
def _data(self):
Expand Down
2 changes: 1 addition & 1 deletion lib/python/pyflyby/_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def exports(self):
members = [n for n in members if not n.startswith("_")]

# Filter out artificially added "deep" members.
members = [(n, None) for n in members if "." not in n]
members = tuple([(n, None) for n in members if "." not in n])
if not members:
return None
return ImportSet(
Expand Down
20 changes: 10 additions & 10 deletions tests/test_importstmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,24 +218,24 @@ def test_ImportStatement_eqne_2():
assert not (stmt1a == stmt2 )


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: None)
@patch("black.files.find_pyproject_toml", lambda root: None)
def test_ImportStatement_pretty_print_black_no_config():
# running should not error out when no pyproject.toml file is found
stmt = ImportStatement("from a import b")
result = stmt.pretty_print(params=FormatParams(use_black=True))
assert isinstance(result, str)


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: None)
@patch("black.files.find_pyproject_toml", lambda root: None)
def test_read_black_config_no_config():
# reading black config should work when no pyproject.toml file is found
config = read_black_config()
assert config == {}


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("black.files.find_pyproject_toml", lambda root: "pyproject.toml")
@patch(
"pyflyby._importstmt.parse_pyproject_toml",
"black.files.parse_pyproject_toml",
lambda path: {
"line_length": 80,
"skip_magic_trailing_comma": True,
Expand All @@ -253,21 +253,21 @@ def test_read_black_config_extracts_config_subset():
assert "skip_source_first_line" not in config


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": ["py310", "py311"]})
@patch("black.files.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("black.files.parse_pyproject_toml", lambda path: {"target_version": ["py310", "py311"]})
def test_read_black_config_target_version_list():
config = read_black_config()
assert config["target_version"] == {"py310", "py311"}


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": "py311"})
@patch("black.files.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("black.files.parse_pyproject_toml", lambda path: {"target_version": "py311"})
def test_read_black_config_target_version_str():
config = read_black_config()
assert config["target_version"] == "py311"

@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": object()})
@patch("black.files.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("black.files.parse_pyproject_toml", lambda path: {"target_version": object()})
def test_read_black_config_target_version_other():
with raises(ValueError, match="Invalid config for black"):
read_black_config()

0 comments on commit a058ba9

Please sign in to comment.