Skip to content

Commit

Permalink
feat: correct ax.legend() order for stack hists (#527)
Browse files Browse the repository at this point in the history
* correct ax.legend() order for stack hists

* cleaner order for stacked hist

* correctly update colors when stack

* replaced ax.get_lines.get_next_color()

* correct color logic

* fix: tests for new legend order and color management

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: use internal mpl function for color

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update src/mplhep/plot.py

* Update src/mplhep/plot.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Andrzej Novak <[email protected]>
  • Loading branch information
3 people authored Oct 18, 2024
1 parent 906d9a2 commit 19f9523
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 3 deletions.
15 changes: 15 additions & 0 deletions src/mplhep/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,20 @@ def iterable_not_string(arg):
##########
# Plotting
return_artists: list[StairsArtists | ErrorBarArtists] = []
# customize color cycle assignment when stacking to match legend
if stack:
plottables = plottables[::-1]
_chunked_kwargs = _chunked_kwargs[::-1]
_labels = _labels[::-1]
if "color" not in kwargs:
# Inverse default color cycle
_colors = []
for _ in range(len(plottables)):
_colors.append(ax._get_lines.get_next_color()) # type: ignore[attr-defined]
_colors.reverse()
for i in range(len(plottables)):
_chunked_kwargs[i].update({"color": _colors[i]})

if histtype == "step":
for i in range(len(plottables)):
do_errors = yerr is not False and (
Expand All @@ -419,6 +433,7 @@ def iterable_not_string(arg):
_kwargs = _chunked_kwargs[i]
_label = _labels[i] if do_errors else None
_step_label = _labels[i] if not do_errors else None

_kwargs = soft_update_kwargs(_kwargs, {"linewidth": 1.5})

_plot_info = plottables[i].to_stairs()
Expand Down
Binary file modified tests/baseline/test_histplot_kwargs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_histplot_real.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_histplot_stack.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 8 additions & 3 deletions tests/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def mock_matplotlib(mocker):
ax.plot.return_value = (line2d,)
ax.errorbar.return_value = (line2d,)

# Mock the _get_lines attribute
_get_lines = mocker.Mock()
_get_lines.get_next_color.return_value = "next-color"
ax._get_lines = _get_lines

mpl = mocker.patch("matplotlib.pyplot", autospec=True)
mocker.patch("matplotlib.pyplot.subplots", return_value=(fig, ax))

Expand Down Expand Up @@ -93,7 +98,7 @@ def test_histplot_real(mock_matplotlib):
hep.histplot([c], bins=bins, ax=ax, yerr=True, histtype="errorbar", label="Data")
ax.legend()
ax.set_title("Data/MC")
assert len(ax.mock_calls) == 18
assert len(ax.mock_calls) == 20
ax.reset_mock()

hep.histplot(
Expand All @@ -110,7 +115,7 @@ def test_histplot_real(mock_matplotlib):
)
ax.legend()
ax.set_title("Data/MC binwnorm")
assert len(ax.mock_calls) == 18
assert len(ax.mock_calls) == 20
ax.reset_mock()

hep.histplot(
Expand All @@ -127,4 +132,4 @@ def test_histplot_real(mock_matplotlib):
)
ax.legend()
ax.set_title("Data/MC Density")
assert len(ax.mock_calls) == 18
assert len(ax.mock_calls) == 20

0 comments on commit 19f9523

Please sign in to comment.