From bf7f7afbb6f5eeaed3f55b12a9881bcc0bb6c13f Mon Sep 17 00:00:00 2001 From: Tristan Fillinger Date: Tue, 15 Oct 2024 10:13:39 +0900 Subject: [PATCH] correct ax.legend() order for stack hists --- src/mplhep/plot.py | 15 +++++++++++++++ src/mplhep/utils.py | 4 +++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/mplhep/plot.py b/src/mplhep/plot.py index 7f8e215e..ad8e504c 100644 --- a/src/mplhep/plot.py +++ b/src/mplhep/plot.py @@ -2,6 +2,7 @@ import collections.abc import inspect +import itertools import logging from collections import OrderedDict from typing import TYPE_CHECKING, Any, NamedTuple, Union @@ -378,6 +379,20 @@ def iterable_not_string(arg): _chunked_kwargs = [_chunked_kwargs[ix] for ix in order] _labels = [_labels[ix] for ix in order] + elif stack: + # Sort from top to bottom so ax.legend() works as expected + order = np.argsort(label)[::-1] + plottables = [plottables[ix] for ix in order] + _chunked_kwargs = [_chunked_kwargs[ix] for ix in order] + _labels = [_labels[ix] for ix in order] + if "color" not in kwargs: + # Inverse default color cycle + _colors = itertools.cycle( + plt.rcParams["axes.prop_cycle"][len(plottables) - 1 :: -1] + ) + for i in range(len(plottables)): + _chunked_kwargs[i].update(next(_colors)) + # ############################ # # Stacking, norming, density if density is True and binwnorm is not None: diff --git a/src/mplhep/utils.py b/src/mplhep/utils.py index 68468d33..c1935f3e 100644 --- a/src/mplhep/utils.py +++ b/src/mplhep/utils.py @@ -294,6 +294,8 @@ def to_errorbar(self): def stack(*plottables): + # Sort from top to bottom so ax.legend() works as expected + plottables = plottables[::-1] baseline = np.nan_to_num(copy.deepcopy(plottables[0].values), 0) for i in range(1, len(plottables)): _mask = np.isnan(plottables[i].values) @@ -303,7 +305,7 @@ def stack(*plottables): baseline += np.nan_to_num(plottables[i].values, 0) plottables[i].values = np.nansum([plottables[i].values, _baseline], axis=0) plottables[i].values[_mask] = np.nan - return plottables + return plottables[::-1] def align_marker(