diff --git a/src/mplhep/plot.py b/src/mplhep/plot.py index ad8e504c..dfb9e269 100644 --- a/src/mplhep/plot.py +++ b/src/mplhep/plot.py @@ -379,20 +379,6 @@ def iterable_not_string(arg): _chunked_kwargs = [_chunked_kwargs[ix] for ix in order] _labels = [_labels[ix] for ix in order] - elif stack: - # Sort from top to bottom so ax.legend() works as expected - order = np.argsort(label)[::-1] - plottables = [plottables[ix] for ix in order] - _chunked_kwargs = [_chunked_kwargs[ix] for ix in order] - _labels = [_labels[ix] for ix in order] - if "color" not in kwargs: - # Inverse default color cycle - _colors = itertools.cycle( - plt.rcParams["axes.prop_cycle"][len(plottables) - 1 :: -1] - ) - for i in range(len(plottables)): - _chunked_kwargs[i].update(next(_colors)) - # ############################ # # Stacking, norming, density if density is True and binwnorm is not None: @@ -424,6 +410,19 @@ def iterable_not_string(arg): ########## # Plotting return_artists: list[StairsArtists | ErrorBarArtists] = [] + + if stack: + plottables = plottables[::-1] + _chunked_kwargs = _chunked_kwargs[::-1] + _labels = _labels[::-1] + if "color" not in kwargs: + # Inverse default color cycle + _colors = itertools.cycle( + plt.rcParams["axes.prop_cycle"][len(plottables) - 1 :: -1] + ) + for i in range(len(plottables)): + _chunked_kwargs[i].update(next(_colors)) + if histtype == "step": for i in range(len(plottables)): do_errors = yerr is not False and ( diff --git a/src/mplhep/utils.py b/src/mplhep/utils.py index c1935f3e..68468d33 100644 --- a/src/mplhep/utils.py +++ b/src/mplhep/utils.py @@ -294,8 +294,6 @@ def to_errorbar(self): def stack(*plottables): - # Sort from top to bottom so ax.legend() works as expected - plottables = plottables[::-1] baseline = np.nan_to_num(copy.deepcopy(plottables[0].values), 0) for i in range(1, len(plottables)): _mask = np.isnan(plottables[i].values) @@ -305,7 +303,7 @@ def stack(*plottables): baseline += np.nan_to_num(plottables[i].values, 0) plottables[i].values = np.nansum([plottables[i].values, _baseline], axis=0) plottables[i].values[_mask] = np.nan - return plottables[::-1] + return plottables def align_marker(