From 906d9a2cc266d466b3f58ce20b660292b9adca84 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Tue, 8 Oct 2024 00:17:43 -0400 Subject: [PATCH] chore: use more ruff (#524) Signed-off-by: Henry Schreiner --- .pre-commit-config.yaml | 22 +---- docs/source/conf.py | 13 +-- examples/Examples.ipynb | 2 +- pyproject.toml | 58 ++++++++--- src/mplhep/__init__.py | 2 +- src/mplhep/_deprecate.py | 8 +- src/mplhep/_version.pyi | 2 - src/mplhep/alice.py | 4 +- src/mplhep/atlas.py | 4 +- src/mplhep/cms.py | 4 +- src/mplhep/label.py | 56 ++++------- src/mplhep/lhcb.py | 4 +- src/mplhep/plot.py | 180 ++++++++++++++++++---------------- src/mplhep/styles/__init__.py | 1 + src/mplhep/utils.py | 31 +++--- tests/test_basic.py | 26 ++++- tests/test_inputs.py | 2 +- tests/test_layouts.py | 2 +- tests/test_mock.py | 2 +- tests/test_notebooks.py | 2 +- tests/test_styles.py | 8 +- 21 files changed, 226 insertions(+), 207 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4fb60555..e8d62d65 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,31 +16,13 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace -- repo: https://github.com/asottile/setup-cfg-fmt - rev: "v2.5.0" - hooks: - - id: setup-cfg-fmt - args: ["--include-version-classifiers", "--max-py-version=3.12"] - -- repo: https://github.com/nbQA-dev/nbQA - rev: 1.8.7 - hooks: - - id: nbqa-pyupgrade - additional_dependencies: [pyupgrade] - args: ["--py38-plus"] - -- repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.8.0 - hooks: - - id: black-jupyter - - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.6.7 + rev: v0.6.9 hooks: # Run the linter. - id: ruff - args: [ --fix ] + args: [ --fix, --show-fixes ] # Run the formatter. - id: ruff-format diff --git a/docs/source/conf.py b/docs/source/conf.py index ee2b8827..2363f757 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,11 +22,9 @@ # Add mplhep to path for sphinx-automodapi sys.path.insert(0, os.path.abspath("../../src")) -import mplhep # noqa: E402 from pathlib import Path -print("sys.path:", sys.path) -print("mplhep version:", mplhep.__version__) +import mplhep # -- Project information ----------------------------------------------------- @@ -70,9 +68,10 @@ def linkcode_resolve(domain, info): mod = importlib.import_module(info["module"]) modpath = [p for p in sys.path if mod.__file__.startswith(p)] if len(modpath) < 1: - raise RuntimeError("Cannot deduce module path") + msg = "Cannot deduce module path" + raise RuntimeError(msg) modpath = modpath[0] - obj = reduce(getattr, [mod] + info["fullname"].split(".")) + obj = reduce(getattr, [mod, *info["fullname"].split(".")]) try: path = inspect.getsourcefile(obj) relpath = path[len(modpath) + 1 :] @@ -80,9 +79,7 @@ def linkcode_resolve(domain, info): except TypeError: # skip property or other type that inspect doesn't like return None - return "http://github.com/scikit-hep/mplhep/blob/{}/{}#L{}".format( - githash, relpath, lineno - ) + return f"http://github.com/scikit-hep/mplhep/blob/{githash}/{relpath}#L{lineno}" intersphinx_mapping = { diff --git a/examples/Examples.ipynb b/examples/Examples.ipynb index 35c30ef7..b23e4be4 100644 --- a/examples/Examples.ipynb +++ b/examples/Examples.ipynb @@ -11,8 +11,8 @@ }, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "import mplhep as hep" ] diff --git a/pyproject.toml b/pyproject.toml index 8b838232..040b05df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,9 +36,7 @@ dependencies = [ [project.optional-dependencies] dev = [ - "black", "bumpversion", - "flake8", "jupyter", "pre-commit", "twine", @@ -67,14 +65,16 @@ version.source = "vcs" build.hooks.vcs.version-file = "src/mplhep/_version.py" -[tool.nbqa.mutate] -pyupgrade = 1 - -[tool.nbqa.addopts] -pyupgrade = ["--py38-plus"] +[tool.uv] +environments = [ + "python_version >= '3.10'", +] +dev-dependencies = [ + "mplhep[test]", +] -[tools.mypy] +[tool.mypy] files = ["src"] python_version = 3.8 warn_unused_configs = true @@ -85,7 +85,7 @@ allow_redefinition = true # disallow_untyped_calls = true # disallow_untyped_defs = true # disallow_incomplete_defs = true -check_untyped_defs = true +check_untyped_defs = false # disallow_untyped_decorators = true # no_implicit_optional = true # warn_redundant_casts = true @@ -106,5 +106,41 @@ testpaths = ["tests"] python_files = "test*.py" -[tool.isort] -profile = "black" +[tool.ruff.lint] +extend-select = [ + "B", # flake8-bugbear + "I", # isort + "ARG", # flake8-unused-arguments + "C4", # flake8-comprehensions + "EM", # flake8-errmsg + "ICN", # flake8-import-conventions + "G", # flake8-logging-format + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "RET", # flake8-return + "RUF", # Ruff-specific + "SIM", # flake8-simplify + "T20", # flake8-print + "UP", # pyupgrade + "YTT", # flake8-2020 + "EXE", # flake8-executable + "NPY", # NumPy specific rules + "PD", # pandas-vet + "FURB", # refurb + "PYI", # flake8-pyi +] +ignore = [ + "PLR09", # Too many <...> + "PLR2004", # Magic value used in comparison + "ISC001", # Conflicts with formatter + "NPY002", # np.random.Generator + "G004", # Logging with f-string + "PD011", # .values vs .to_numpy can get confused + "PT013", # Import from pytest +] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["T20"] +"docs/conf.py" = ["T20"] diff --git a/src/mplhep/__init__.py b/src/mplhep/__init__.py index 7768aca4..948671b6 100644 --- a/src/mplhep/__init__.py +++ b/src/mplhep/__init__.py @@ -23,8 +23,8 @@ rescale_to_axessize, sort_legend, ylow, - yscale_legend, yscale_anchored_text, + yscale_legend, ) from .styles import set_style diff --git a/src/mplhep/_deprecate.py b/src/mplhep/_deprecate.py index 61544476..5562f854 100644 --- a/src/mplhep/_deprecate.py +++ b/src/mplhep/_deprecate.py @@ -48,13 +48,9 @@ def __init__(self, name, reason="", warn_once: bool = True, warning=FutureWarnin def __call__(self, func): def decorated_func(*args, **kwargs): - if self._name in kwargs.keys() and not ( - self._warn_once and self._already_warned - ): + if self._name in kwargs and not (self._warn_once and self._already_warned): warnings.warn( - 'kwarg "{}" in function ``{}`` is deprecated and may be removed in future versions: {}'.format( - self._name, func.__name__, self._reason - ), + f'kwarg "{self._name}" in function ``{func.__name__}`` is deprecated and may be removed in future versions: {self._reason}', category=self._warning, stacklevel=2, ) diff --git a/src/mplhep/_version.pyi b/src/mplhep/_version.pyi index 91744f98..5bb2b22f 100644 --- a/src/mplhep/_version.pyi +++ b/src/mplhep/_version.pyi @@ -1,4 +1,2 @@ -from __future__ import annotations - version: str version_tuple: tuple[int, int, int] | tuple[int, int, int, str, str] diff --git a/src/mplhep/alice.py b/src/mplhep/alice.py index 4892a66a..912db6c1 100644 --- a/src/mplhep/alice.py +++ b/src/mplhep/alice.py @@ -19,7 +19,7 @@ def text(text="", **kwargs): for key, value in dict(mplhep.rcParams.text._get_kwargs()).items(): if ( value is not None - and key not in kwargs.keys() + and key not in kwargs and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs ): kwargs.setdefault(key, value) @@ -31,7 +31,7 @@ def label(label=None, **kwargs): for key, value in dict(mplhep.rcParams.label._get_kwargs()).items(): if ( value is not None - and key not in kwargs.keys() + and key not in kwargs and key in inspect.getfullargspec(label_base.exp_label).kwonlyargs ): kwargs.setdefault(key, value) diff --git a/src/mplhep/atlas.py b/src/mplhep/atlas.py index 15b40057..6a1f8b86 100644 --- a/src/mplhep/atlas.py +++ b/src/mplhep/atlas.py @@ -21,7 +21,7 @@ def text(text="", **kwargs): for key, value in dict(mplhep.rcParams.text._get_kwargs()).items(): if ( value is not None - and key not in kwargs.keys() + and key not in kwargs and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs ): kwargs.setdefault(key, value) @@ -35,7 +35,7 @@ def label(label=None, **kwargs): for key, value in dict(mplhep.rcParams.label._get_kwargs()).items(): if ( value is not None - and key not in kwargs.keys() + and key not in kwargs and key in inspect.getfullargspec(label_base.exp_label).kwonlyargs ): kwargs.setdefault(key, value) diff --git a/src/mplhep/cms.py b/src/mplhep/cms.py index 8d9d54d0..acd8cf3b 100644 --- a/src/mplhep/cms.py +++ b/src/mplhep/cms.py @@ -21,7 +21,7 @@ def text(text="", **kwargs): for key, value in dict(mplhep.rcParams.text._get_kwargs()).items(): if ( value is not None - and key not in kwargs.keys() + and key not in kwargs and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs ): kwargs.setdefault(key, value) @@ -35,7 +35,7 @@ def label(label=None, **kwargs): for key, value in dict(mplhep.rcParams.label._get_kwargs()).items(): if ( value is not None - and key not in kwargs.keys() + and key not in kwargs and key in inspect.getfullargspec(label_base.exp_label).kwonlyargs ): kwargs.setdefault(key, value) diff --git a/src/mplhep/label.py b/src/mplhep/label.py index adcfe77a..c1b21e8e 100644 --- a/src/mplhep/label.py +++ b/src/mplhep/label.py @@ -11,25 +11,17 @@ class ExpText(mtext.Text): def __repr__(self): - return "exptext: Custom Text({}, {}, {})".format( - self._x, self._y, repr(self._text) - ) + return f"exptext: Custom Text({self._x}, {self._y}, {self._text!r})" class ExpSuffix(mtext.Text): def __repr__(self): - return "expsuffix: Custom Text({}, {}, {})".format( - self._x, - self._y, - repr(self._text), - ) + return f"expsuffix: Custom Text({self._x}, {self._y}, {self._text!r})" class SuppText(mtext.Text): def __repr__(self): - return "supptext: Custom Text({}, {}, {})".format( - self._x, self._y, repr(self._text) - ) + return f"supptext: Custom Text({self._x}, {self._y}, {self._text!r})" def exp_text( @@ -105,13 +97,14 @@ def exp_text( } if loc not in [0, 1, 2, 3, 4]: - raise ValueError( + msg = ( "loc must be in {0, 1, 2}:\n" "0 : Above axes, left aligned\n" "1 : Top left corner\n" "2 : Top left corner, multiline\n" "3 : Split EXP above axes, rest of label in top left corner\n" ) + raise ValueError(msg) def pixel_to_axis(extent, ax=None): # Transform pixel bbox extends to axis fractions @@ -134,10 +127,7 @@ def dist(tup): abs(y1 - y) / dimy, ) - if loc in [0, 3]: - _exp_loc = 0 - else: - _exp_loc = 1 + _exp_loc = 0 if loc in [0, 3] else 1 _formater = ax.get_yaxis().get_major_formatter() if isinstance(_formater, mpl.ticker.ScalarFormatter) and _exp_loc == 0: _sci_box = pixel_to_axis( @@ -209,14 +199,7 @@ def dist(tup): units="inches", fig=ax.figure, ) - elif loc == 2: - _t = mtransforms.offset_copy( - expsuffix._transform, - y=-expsuffix.get_window_extent().height / _dpi, - units="inches", - fig=ax.figure, - ) - elif loc == 3: + elif loc in (2, 3): _t = mtransforms.offset_copy( expsuffix._transform, y=-expsuffix.get_window_extent().height / _dpi, @@ -368,18 +351,17 @@ def exp_label( # Right label if rlabel is not None: _lumi = rlabel + elif lumi is not None: + _lumi = r"{lumi}{year} ({com} TeV)".format( + lumi=lumi_format.format(lumi) + r" $\mathrm{fb^{-1}}$", + year=", " + str(year) if year is not None else "", + com=str(com) if com is not None else "13", + ) else: - if lumi is not None: - _lumi = r"{lumi}{year} ({com} TeV)".format( - lumi=lumi_format.format(lumi) + r" $\mathrm{fb^{-1}}$", - year=", " + str(year) if year is not None else "", - com=str(com) if com is not None else "13", - ) - else: - _lumi = "{year} ({com} TeV)".format( - year=str(year) if year is not None else "", - com=str(com) if com is not None else "13", - ) + _lumi = "{year} ({com} TeV)".format( + year=str(year) if year is not None else "", + com=str(com) if com is not None else "13", + ) if loc < 4: lumitext(text=_lumi, ax=ax, fontname=fontname, fontsize=fontsize) @@ -524,7 +506,7 @@ def savelabels( if ax is None: ax = plt.gca() - label_base = [ch for ch in ax.get_children() if isinstance(ch, ExpSuffix)][0] + label_base = next(ch for ch in ax.get_children() if isinstance(ch, ExpSuffix)) _sim = "Simulation" if "Simulation" in label_base.get_text() else "" for label_text, suffix in labels: @@ -534,7 +516,7 @@ def savelabels( save_name = suffix else: if len(suffix) > 0: - suffix = "_" + suffix + suffix = "_" + suffix # noqa: PLW2901 if "." in fname: save_name = f"{fname.split('.')[0]}{suffix}.{fname.split('.')[1]}" else: diff --git a/src/mplhep/lhcb.py b/src/mplhep/lhcb.py index e8bb2cae..191f1242 100644 --- a/src/mplhep/lhcb.py +++ b/src/mplhep/lhcb.py @@ -36,7 +36,7 @@ def text(text="", **kwargs): for key, value in dict(mplhep.rcParams.text._get_kwargs()).items(): if ( value is not None - and key not in kwargs.keys() + and key not in kwargs and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs ): kwargs.setdefault(key, value) @@ -53,7 +53,7 @@ def label(label=None, **kwargs): for key, value in dict(mplhep.rcParams.label._get_kwargs()).items(): if ( value is not None - and key not in kwargs.keys() + and key not in kwargs and key in inspect.getfullargspec(label_base.exp_label).kwonlyargs ): kwargs.setdefault(key, value) diff --git a/src/mplhep/plot.py b/src/mplhep/plot.py index 8ae57be2..7f8e215e 100644 --- a/src/mplhep/plot.py +++ b/src/mplhep/plot.py @@ -3,8 +3,8 @@ import collections.abc import inspect import logging -from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Any, Union +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, NamedTuple, Union import matplotlib as mpl import matplotlib.pyplot as plt @@ -15,21 +15,34 @@ from .utils import ( Plottable, + align_marker, get_histogram_axes_title, get_plottable_protocol_bins, hist_object_handler, isLight, process_histogram_parts, - align_marker, to_padded2d, ) if TYPE_CHECKING: from numpy.typing import ArrayLike -StairsArtists = namedtuple("StairsArtists", "stairs errorbar legend_artist") -ErrorBarArtists = namedtuple("ErrorBarArtists", "errorbar") -ColormeshArtists = namedtuple("ColormeshArtists", "pcolormesh cbar text") + +class StairsArtists(NamedTuple): + stairs: Any + errorbar: Any + legend_artist: Any + + +class ErrorBarArtists(NamedTuple): + errorbar: Any + + +class ColormeshArtists(NamedTuple): + pcolormesh: Any + cbar: Any + text: Any + Hist1DArtists = Union[StairsArtists, ErrorBarArtists] Hist2DArtists = ColormeshArtists @@ -44,7 +57,7 @@ def soft_update_kwargs(kwargs, mods, rc=True): "lines.linestyle", ] aliases = {"ls": "linestyle", "lw": "linewidth"} - kwargs = {aliases[k] if k in aliases else k: v for k, v in kwargs.items()} + kwargs = {aliases.get(k, k): v for k, v in kwargs.items()} for key, val in mods.items(): rc_modded = (key in not_default) or ( key in [k.split(".")[-1] for k in not_default if k in respect] @@ -150,9 +163,9 @@ def histplot( # ax check if ax is None: ax = plt.gca() - else: - if not isinstance(ax, plt.Axes): - raise ValueError("ax must be a matplotlib Axes object") + elif not isinstance(ax, plt.Axes): + msg = "ax must be a matplotlib Axes object" + raise ValueError(msg) # arg check _allowed_histtype = ["fill", "step", "errorbar", "band"] @@ -186,7 +199,7 @@ def histplot( plottables = [] flow_bins = final_bins - for i, h in enumerate(hists): + for h in hists: value, variance = np.copy(h.values()), h.variances() if has_variances := variance is not None: variance = np.copy(variance) @@ -223,12 +236,10 @@ def histplot( ) # Set plottables - if flow == "none": - plottables.append(Plottable(value, edges=final_bins, variances=variance)) - elif flow == "hint": # modify plottable + if flow in ("none", "hint"): plottables.append(Plottable(value, edges=final_bins, variances=variance)) elif flow == "show": - _flow_bin_size = np.max( + _flow_bin_size: float = np.max( [0.05 * (final_bins[-1] - final_bins[0]), np.mean(np.diff(final_bins))] ) flow_bins = np.copy(final_bins) @@ -264,7 +275,8 @@ def histplot( _plottable.method = w2method if w2 is not None and yerr is not None: - raise ValueError("Can only supply errors or w2") + msg = "Can only supply errors or w2" + raise ValueError(msg) _labels: list[str | None] if label is None: @@ -287,7 +299,7 @@ def iterable_not_string(arg): if iterable_not_string(kwargs[kwarg]): # Check if tuple of floats or ints (can be used for colors) if isinstance(kwargs[kwarg], tuple) and all( - isinstance(x, int) or isinstance(x, float) for x in kwargs[kwarg] + isinstance(x, (int, float)) for x in kwargs[kwarg] ): for i in range(len(_chunked_kwargs)): _chunked_kwargs[i][kwarg] = kwargs[kwarg] @@ -326,7 +338,8 @@ def iterable_not_string(arg): elif _yerr.shape[-2] == 1: # [[1,1]] _yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1]) else: - raise ValueError("yerr format is not understood") + msg = "yerr format is not understood" + raise ValueError(msg) elif _yerr.ndim == 2: # Broadcast yerr (nh, N) to (nh, 2, N) _yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1]) @@ -336,7 +349,8 @@ def iterable_not_string(arg): len(plottables), 2, _yerr.shape[-1] ) else: - raise ValueError("yerr format is not understood") + msg = "yerr format is not understood" + raise ValueError(msg) assert _yerr is not None for yrs, _plottable in zip(_yerr, plottables): @@ -348,18 +362,18 @@ def iterable_not_string(arg): if sort.split("_")[0] in ["l", "label"] and isinstance(_labels, list): order = np.argsort(label) # [::-1] elif sort.split("_")[0] in ["y", "yield"]: - _yields = [np.sum(_h.values) for _h in plottables] + _yields = [np.sum(_h.values) for _h in plottables] # type: ignore[var-annotated] order = np.argsort(_yields) if len(sort.split("_")) == 2 and sort.split("_")[1] == "r": order = order[::-1] - elif isinstance(sort, list) or isinstance(sort, np.ndarray): + elif isinstance(sort, (list, np.ndarray)): if len(sort) != len(plottables): - raise ValueError( - f"Sort indexing array is of the wrong size - {len(sort)}, {len(plottables)} expected." - ) + msg = f"Sort indexing array is of the wrong size - {len(sort)}, {len(plottables)} expected." + raise ValueError(msg) order = np.asarray(sort) else: - raise ValueError(f"Sort type: {sort} not understood.") + msg = f"Sort type: {sort} not understood." + raise ValueError(msg) plottables = [plottables[ix] for ix in order] _chunked_kwargs = [_chunked_kwargs[ix] for ix in order] _labels = [_labels[ix] for ix in order] @@ -367,7 +381,8 @@ def iterable_not_string(arg): # ############################ # # Stacking, norming, density if density is True and binwnorm is not None: - raise ValueError("Can only set density or binwnorm.") + msg = "Can only set density or binwnorm." + raise ValueError(msg) if density is True: if stack: _total = np.sum( @@ -501,9 +516,8 @@ def iterable_not_string(arg): _artist = _e[0] # Add sticky edges for autoscale - assert hasattr( - listy := _artist.sticky_edges.y, "append" - ), "cannot append to sticky edges" + listy = _artist.sticky_edges.y + assert hasattr(listy, "append"), "cannot append to sticky edges" listy.append(0) if xtick_labels is None or flow == "show": @@ -519,7 +533,8 @@ def iterable_not_string(arg): # Flow extra styling if (fig := ax.figure) is None: - raise ValueError("No figure found") + msg = "No figure found" + raise ValueError(msg) if flow == "hint": _marker_size = ( 30 @@ -572,7 +587,7 @@ def iterable_not_string(arg): xticks = _xticks xticklabels = _xticklabels break - elif len(_xticklabels) > 0: + if len(_xticklabels) > 0: xticks = _xticks xticklabels = _xticklabels @@ -741,9 +756,9 @@ def hist2dplot( # ax check if ax is None: ax = plt.gca() - else: - if not isinstance(ax, plt.Axes): - raise ValueError("ax must be a matplotlib Axes object") + elif not isinstance(ax, plt.Axes): + msg = "ax must be a matplotlib Axes object" + raise ValueError(msg) h = hist_object_handler(H, xbins, ybins) @@ -759,9 +774,6 @@ def hist2dplot( and "flow" not in inspect.getfullargspec(h.values).args and flow is not None ): - print( - f"Warning: {type(h)} is not allowed to get flow bins, flow bin option set to None" - ) flow = None elif flow in ["hint", "show"]: xwidth, ywidth = (xbins[-1] - xbins[0]) * 0.05, (ybins[-1] - ybins[0]) * 0.05 @@ -809,10 +821,11 @@ def hist2dplot( ) except TypeError as error: if "got an unexpected keyword argument 'flow'" in str(error): - raise TypeError( - f"The histograms value method {repr(h)} does not take a 'flow' argument. UHI Plottable doesn't require this to have, but it is required for this function." + msg = ( + f"The histograms value method {h!r} does not take a 'flow' argument. UHI Plottable doesn't require this to have, but it is required for this function." f" Implementations like hist/boost-histogram support this argument." - ) from error + ) + raise TypeError(msg) from error xbin_centers = xbins[1:] - np.diff(xbins) / float(2) ybin_centers = ybins[1:] - np.diff(ybins) / float(2) @@ -828,9 +841,9 @@ def hist2dplot( H = H.T if cmin is not None: - H[H < cmin] = None + H[cmin > H] = None if cmax is not None: - H[H > cmax] = None + H[cmax < H] = None X, Y = np.meshgrid(xbins, ybins) @@ -842,8 +855,8 @@ def hist2dplot( if y_axes_label: ax.set_ylabel(y_axes_label) - ax.set_xlim(xbins[0], xbins[-1]) - ax.set_ylim(ybins[0], ybins[-1]) + ax.set_xlim(xbins[0], xbins[-1]) # type: ignore[arg-type] + ax.set_ylim(ybins[0], ybins[-1]) # type: ignore[arg-type] if xtick_labels is None: # Ordered axis if len(ax.get_xticks()) > len(xbins) * 0.7: @@ -906,7 +919,8 @@ def hist2dplot( ) elif flow == "hint": if (fig := ax.figure) is None: - raise ValueError("No figure found.") + msg = "No figure found." + raise ValueError(msg) _marker_size = ( 30 * ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).width @@ -968,25 +982,29 @@ def hist2dplot( if H.shape == label_array.shape: _labels = label_array else: - raise ValueError( - f"Labels input has incorrect shape (expect: {H.shape}, got: {label_array.shape})" - ) + msg = f"Labels input has incorrect shape (expect: {H.shape}, got: {label_array.shape})" + raise ValueError(msg) elif labels is not None: - raise ValueError( - "Labels not understood, either specify a bool or a Hist-like array" - ) + msg = "Labels not understood, either specify a bool or a Hist-like array" + raise ValueError(msg) text_artists = [] if _labels is not None: if (pccmap := pc.cmap) is None: - raise ValueError("No colormap found.") + msg = "No colormap found." + raise ValueError(msg) for ix, xc in enumerate(xbin_centers): for iy, yc in enumerate(ybin_centers): normedh = pc.norm(H[iy, ix]) color = "black" if isLight(pccmap(normedh)[:-1]) else "lightgrey" text_artists.append( ax.text( - xc, yc, _labels[iy, ix], ha="center", va="center", color=color + xc, + yc, + _labels[iy, ix], # type: ignore[arg-type] + ha="center", + va="center", + color=color, ) ) @@ -1038,8 +1056,7 @@ def overlap(ax, bbox, get_vertices=False): if get_vertices: return overlap, vertices - else: - return overlap + return overlap def _draw_leg_bbox(ax): @@ -1049,9 +1066,9 @@ def _draw_leg_bbox(ax): fig = ax.figure leg = ax.get_legend() if leg is None: - leg = [ + leg = next( c for c in ax.get_children() if isinstance(c, plt.matplotlib.legend.Legend) - ][0] + ) fig.canvas.draw() return leg.get_frame().get_bbox() @@ -1077,7 +1094,7 @@ def _draw_text_bbox(ax): def yscale_legend( ax: mpl.axes.Axes | None = None, - otol: float | int | None = None, + otol: float | None = None, soft_fail: bool = False, ) -> mpl.axes.Axes: """ @@ -1110,23 +1127,22 @@ def yscale_legend( logging.info("Scaling y-axis by 5% to fit legend") ax.set_ylim(ax.get_ylim()[0], ax.get_ylim()[-1] * scale_factor) if (fig := ax.figure) is None: - raise RuntimeError("Could not fetch figure, maybe no plot is drawn yet?") + msg = "Could not fetch figure, maybe no plot is drawn yet?" + raise RuntimeError(msg) fig.canvas.draw() if max_scales > 10: if not soft_fail: - raise RuntimeError( - "Could not fit legend in 10 iterations, return anyway by passing `soft_fail=True`." - ) - else: - logging.warning("Could not fit legend in 10 iterations") - break + msg = "Could not fit legend in 10 iterations, return anyway by passing `soft_fail=True`." + raise RuntimeError(msg) + logging.warning("Could not fit legend in 10 iterations") + break max_scales += 1 return ax def yscale_anchored_text( ax: mpl.axes.Axes | None = None, - otol: float | int | None = None, + otol: float | None = None, soft_fail: bool = False, ) -> mpl.axes.Axes: """ @@ -1159,16 +1175,15 @@ def yscale_anchored_text( logging.info("Scaling y-axis by 5% to fit legend") ax.set_ylim(ax.get_ylim()[0], ax.get_ylim()[-1] * scale_factor) if (fig := ax.figure) is None: - raise RuntimeError("Could not fetch figure, maybe no plot is drawn yet?") + msg = "Could not fetch figure, maybe no plot is drawn yet?" + raise RuntimeError(msg) fig.canvas.draw() if max_scales > 10: if not soft_fail: - raise RuntimeError( - "Could not fit AnchoredText in 10 iterations, return anyway by passing `soft_fail=True`." - ) - else: - logging.warning("Could not fit AnchoredText in 10 iterations") - break + msg = "Could not fit AnchoredText in 10 iterations, return anyway by passing `soft_fail=True`." + raise RuntimeError(msg) + logging.warning("Could not fit AnchoredText in 10 iterations") + break max_scales += 1 return ax @@ -1220,13 +1235,11 @@ def mpl_magic(ax=None, info=True): if ax is None: ax = plt.gca() if info: - print("Running ROOT/CMS style adjustments (hide with info=False):") + pass ax = ylow(ax) ax = yscale_legend(ax) - ax = yscale_anchored_text(ax) - - return ax + return yscale_anchored_text(ax) ######################################## @@ -1295,8 +1308,8 @@ def make_square_add_cbar(ax, size=0.4, pad=0.1): cax = divider.append_axes("right", size=margin_size, pad=pad_size) - divider.set_horizontal([RemainderFixed(xsizes, ysizes, divider)] + xsizes) - divider.set_vertical([RemainderFixed(xsizes, ysizes, divider)] + ysizes) + divider.set_horizontal([RemainderFixed(xsizes, ysizes, divider), *xsizes]) + divider.set_vertical([RemainderFixed(xsizes, ysizes, divider), *ysizes]) return cax @@ -1326,7 +1339,7 @@ def convert(fraction, position=position): pad_size = axes_size.Fixed(pad) xsizes = [pad_size, margin_size] if position in ["top", "bottom"]: - xsizes = xsizes[::-1] + xsizes.reverse() yhax = divider.append_axes(position, size=margin_size, pad=pad_size) if extend: @@ -1338,7 +1351,7 @@ def extend_ratio(ax, yhax): return new_size / orig_size if position in ["right"]: - divider.set_horizontal([axes_size.Fixed(width)] + xsizes) + divider.set_horizontal([axes_size.Fixed(width), *xsizes]) fig.set_size_inches( fig.get_size_inches()[0] * extend_ratio(ax, yhax)[0], fig.get_size_inches()[1], @@ -1357,7 +1370,7 @@ def extend_ratio(ax, yhax): ) ax.get_shared_x_axes().join(ax, yhax) elif position in ["bottom"]: - divider.set_vertical(xsizes + [axes_size.Fixed(height)]) + divider.set_vertical([*xsizes, axes_size.Fixed(height)]) fig.set_size_inches( fig.get_size_inches()[0], fig.get_size_inches()[1] * extend_ratio(ax, yhax)[1], @@ -1401,7 +1414,8 @@ def sort_legend(ax, order=None): elif order is None: ordered_label_list = labels else: - raise TypeError(f"Unexpected values type of order: {type(order)}") + msg = f"Unexpected values type of order: {type(order)}" + raise TypeError(msg) ordered_label_list = [entry for entry in ordered_label_list if entry in labels] ordered_label_values = [by_label[k] for k in ordered_label_list] diff --git a/src/mplhep/styles/__init__.py b/src/mplhep/styles/__init__.py index 7d8c27a2..01c14aa9 100644 --- a/src/mplhep/styles/__init__.py +++ b/src/mplhep/styles/__init__.py @@ -72,6 +72,7 @@ def use(styles=None): ] plt_style.use(styles) + return None fira = {"font.sans-serif": "Fira Sans"} diff --git a/src/mplhep/utils.py b/src/mplhep/utils.py index f9355454..68468d33 100644 --- a/src/mplhep/utils.py +++ b/src/mplhep/utils.py @@ -6,10 +6,10 @@ from typing import TYPE_CHECKING, Any, Iterable, Sequence import numpy as np -from uhi.numpy_plottable import ensure_plottable_histogram -from uhi.typing.plottable import PlottableAxis, PlottableHistogram from matplotlib import markers from matplotlib.path import Path +from uhi.numpy_plottable import ensure_plottable_histogram +from uhi.typing.plottable import PlottableAxis, PlottableHistogram if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -52,12 +52,14 @@ def hist_object_handler( hist = (hist, None) hist_obj = ensure_plottable_histogram(hist) elif isinstance(hist, PlottableHistogram): - raise TypeError("Cannot give bins with existing histogram") + msg = "Cannot give bins with existing histogram" + raise TypeError(msg) else: hist_obj = ensure_plottable_histogram((hist, *bins)) if len(hist_obj.axes) not in {1, 2}: - raise ValueError("Must have only 1 or 2 axes") + msg = "Must have only 1 or 2 axes" + raise ValueError(msg) return hist_obj @@ -89,12 +91,9 @@ def process_histogram_parts( """ # Try to understand input - if (isinstance(H, list) or isinstance(H, np.ndarray)) and not isinstance( - H[0], (Real) - ): + if (isinstance(H, (list, np.ndarray))) and not isinstance(H[0], (Real)): return _process_histogram_parts_iter(H, *bins) - else: - return _process_histogram_parts_iter((H,), *bins) # type: ignore[arg-type] + return _process_histogram_parts_iter((H,), *bins) # type: ignore[arg-type] def _process_histogram_parts_iter( @@ -127,9 +126,9 @@ def get_histogram_axes_title(axis: Any) -> str: if hasattr(axis, "label"): return axis.label # Classic support for older hist, deprecated - elif hasattr(axis, "title"): + if hasattr(axis, "title"): return axis.title - elif hasattr(axis, "name"): + if hasattr(axis, "name"): return axis.name # No axis title found @@ -218,9 +217,8 @@ def calculate_relative(method_fcn, variances): elif callable(method): self.yerr_lo, self.yerr_hi = calculate_relative(method, variances) else: - raise RuntimeError( - "``method'' needs to be a callable or 'poisson' or 'sqrt'." - ) + msg = "``method'' needs to be a callable or 'poisson' or 'sqrt'." + raise RuntimeError(msg) self.yerr_lo = np.nan_to_num(self.yerr_lo, 0) self.yerr_hi = np.nan_to_num(self.yerr_hi, 0) @@ -388,7 +386,7 @@ def to_padded2d(h, variances=False): variances_flow = h.variances(flow=True) xpadlo, xpadhi = 1 - h.axes[0].traits.underflow, 1 - h.axes[0].traits.overflow ypadlo, ypadhi = 1 - h.axes[1].traits.underflow, 1 - h.axes[1].traits.overflow - xpadhi_m, mypadhi_m = [-pad if pad != 0 else None for pad in [xpadhi, ypadhi]] + xpadhi_m, mypadhi_m = (-pad if pad != 0 else None for pad in [xpadhi, ypadhi]) padded = np.zeros( ( @@ -401,5 +399,4 @@ def to_padded2d(h, variances=False): padded_varis[xpadlo:xpadhi_m, ypadlo:mypadhi_m] = variances_flow if variances: return padded, padded_varis - else: - return padded + return padded diff --git a/tests/test_basic.py b/tests/test_basic.py index 7ec98be0..2ff56d6f 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import re import hist import matplotlib.pyplot as plt @@ -9,7 +10,14 @@ os.environ["RUNNING_PYTEST"] = "true" -import mplhep as hep # noqa: E402 +try: + # NumPy 2 + from numpy.char import chararray +except ModuleNotFoundError: + # NumPy 1 + from numpy import chararray + +import mplhep as hep """ To test run: @@ -463,18 +471,26 @@ def test_hist2dplot_labels_option(): assert hep.hist2dplot(H, xedges, yedges, labels=False) - label_array = np.chararray(H.shape, itemsize=2) + label_array = chararray(H.shape, itemsize=2) label_array[:] = "hi" assert hep.hist2dplot(H, xedges, yedges, labels=label_array) - label_array = np.chararray(H.shape[0], itemsize=2) + label_array = chararray(H.shape[0], itemsize=2) label_array[:] = "hi" # Label array shape invalid - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=re.escape("Labels input has incorrect shape (expect: (5, 7), got: (7,))"), + ): hep.hist2dplot(H, xedges, yedges, labels=label_array) # Invalid label type - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=re.escape( + "Labels not understood, either specify a bool or a Hist-like array" + ), + ): hep.hist2dplot(H, xedges, yedges, labels=5) diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 6bcc59a7..d44dd988 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -9,7 +9,7 @@ os.environ["RUNNING_PYTEST"] = "true" -import mplhep as hep # noqa: E402 +import mplhep as hep """ To test run: diff --git a/tests/test_layouts.py b/tests/test_layouts.py index 73b17aac..11ad4e46 100644 --- a/tests/test_layouts.py +++ b/tests/test_layouts.py @@ -7,7 +7,7 @@ os.environ["RUNNING_PYTEST"] = "true" -import mplhep as hep # noqa: E402 +import mplhep as hep """ To test run: diff --git a/tests/test_mock.py b/tests/test_mock.py index 98a48579..28f9597a 100644 --- a/tests/test_mock.py +++ b/tests/test_mock.py @@ -3,7 +3,7 @@ from types import SimpleNamespace import matplotlib.lines -import matplotlib.pyplot +import matplotlib.pyplot # noqa: ICN001 import numpy as np import pytest from pytest import approx diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index baa3d353..72962397 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -9,7 +9,7 @@ os.environ["RUNNING_PYTEST"] = "true" -@pytest.fixture() +@pytest.fixture def common_kwargs(tmpdir): outputnb = tmpdir.join("output.ipynb") return { diff --git a/tests/test_styles.py b/tests/test_styles.py index 3df81580..9da26e80 100644 --- a/tests/test_styles.py +++ b/tests/test_styles.py @@ -9,7 +9,7 @@ os.environ["RUNNING_PYTEST"] = "true" -import mplhep as hep # noqa: E402 +import mplhep as hep """ To test run: @@ -126,7 +126,7 @@ def test_use_style_LHCb_dep(fig_test, fig_ref): @pytest.mark.skipif(sys.platform != "linux", reason="Linux only") @check_figures_equal(extensions=["pdf"]) @pytest.mark.parametrize( - "mplhep_style, str_alias", + ("mplhep_style", "str_alias"), [ (hep.style.ALICE, "ALICE"), (hep.style.ATLAS, "ATLAS"), @@ -153,7 +153,7 @@ def test_use_style_str_alias(fig_test, fig_ref, mplhep_style, str_alias): @pytest.mark.skipif(sys.platform != "linux", reason="Linux only") @check_figures_equal(extensions=["pdf"]) @pytest.mark.parametrize( - "mplhep_style, str_alias", + ("mplhep_style", "str_alias"), [ (hep.style.ALICE, "ALICE"), (hep.style.ATLAS, "ATLAS"), @@ -180,7 +180,7 @@ def test_use_style_self_consistent(fig_test, fig_ref, mplhep_style, str_alias): @pytest.mark.skipif(sys.platform != "linux", reason="Linux only") @check_figures_equal(extensions=["pdf"]) @pytest.mark.parametrize( - "mplhep_style, str_alias", + ("mplhep_style", "str_alias"), [ (hep.style.ALICE, "ALICE"), (hep.style.ATLAS, "ATLAS"),