Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: correct ax.legend() order for stack hists #527

Merged
merged 12 commits into from
Oct 18, 2024
15 changes: 15 additions & 0 deletions src/mplhep/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,20 @@ def iterable_not_string(arg):
##########
# Plotting
return_artists: list[StairsArtists | ErrorBarArtists] = []
# customize color cycle assignment when stacking to match legend
if stack:
plottables = plottables[::-1]
_chunked_kwargs = _chunked_kwargs[::-1]
_labels = _labels[::-1]
if "color" not in kwargs:
# Inverse default color cycle
_colors = []
for _ in range(len(plottables)):
_colors.append(ax._get_lines.get_next_color()) # type: ignore[attr-defined]
_colors.reverse()
for i in range(len(plottables)):
_chunked_kwargs[i].update({"color": _colors[i]})

if histtype == "step":
for i in range(len(plottables)):
do_errors = yerr is not False and (
Expand All @@ -419,6 +433,7 @@ def iterable_not_string(arg):
_kwargs = _chunked_kwargs[i]
_label = _labels[i] if do_errors else None
_step_label = _labels[i] if not do_errors else None

_kwargs = soft_update_kwargs(_kwargs, {"linewidth": 1.5})

_plot_info = plottables[i].to_stairs()
Expand Down
Binary file modified tests/baseline/test_histplot_kwargs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_histplot_real.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_histplot_stack.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 8 additions & 3 deletions tests/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def mock_matplotlib(mocker):
ax.plot.return_value = (line2d,)
ax.errorbar.return_value = (line2d,)

# Mock the _get_lines attribute
_get_lines = mocker.Mock()
_get_lines.get_next_color.return_value = "next-color"
ax._get_lines = _get_lines

mpl = mocker.patch("matplotlib.pyplot", autospec=True)
mocker.patch("matplotlib.pyplot.subplots", return_value=(fig, ax))

Expand Down Expand Up @@ -93,7 +98,7 @@ def test_histplot_real(mock_matplotlib):
hep.histplot([c], bins=bins, ax=ax, yerr=True, histtype="errorbar", label="Data")
ax.legend()
ax.set_title("Data/MC")
assert len(ax.mock_calls) == 18
assert len(ax.mock_calls) == 20
ax.reset_mock()

hep.histplot(
Expand All @@ -110,7 +115,7 @@ def test_histplot_real(mock_matplotlib):
)
ax.legend()
ax.set_title("Data/MC binwnorm")
assert len(ax.mock_calls) == 18
assert len(ax.mock_calls) == 20
ax.reset_mock()

hep.histplot(
Expand All @@ -127,4 +132,4 @@ def test_histplot_real(mock_matplotlib):
)
ax.legend()
ax.set_title("Data/MC Density")
assert len(ax.mock_calls) == 18
assert len(ax.mock_calls) == 20
Loading