Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: factor get_plottables logic out of histplot #534

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mplhep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
yscale_legend,
)
from .styles import set_style
from .utils import get_plottables

# Configs
rcParams = Config(
Expand Down Expand Up @@ -76,4 +77,5 @@
"sort_legend",
"save_variations",
"set_style",
"get_plottables",
]
171 changes: 14 additions & 157 deletions src/mplhep/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from mpl_toolkits.axes_grid1 import axes_size, make_axes_locatable

from .utils import (
Plottable,
align_marker,
get_histogram_axes_title,
get_plottable_protocol_bins,
get_plottables,
hist_object_handler,
isLight,
process_histogram_parts,
Expand Down Expand Up @@ -198,86 +198,18 @@ def histplot(
else get_histogram_axes_title(hists[0].axes[0])
)

plottables = []
flow_bins = final_bins
for h in hists:
value, variance = np.copy(h.values()), h.variances()
if has_variances := variance is not None:
variance = np.copy(variance)
underflow, overflow = 0.0, 0.0
underflowv, overflowv = 0.0, 0.0
# One sided flow bins - hist (uproot hist does not have the over- or underflow traits)
if (
hasattr(h, "axes")
and (traits := getattr(h.axes[0], "traits", None)) is not None
and hasattr(traits, "underflow")
and hasattr(traits, "overflow")
):
if traits.overflow:
overflow = np.copy(h.values(flow=True))[-1]
if has_variances:
overflowv = np.copy(h.variances(flow=True))[-1]
if traits.underflow:
underflow = np.copy(h.values(flow=True))[0]
if has_variances:
underflowv = np.copy(h.variances(flow=True))[0]
# Both flow bins exist - uproot
elif hasattr(h, "values") and "flow" in inspect.getfullargspec(h.values).args:
if len(h.values()) + 2 == len(
h.values(flow=True)
): # easy case, both over/under
underflow, overflow = (
np.copy(h.values(flow=True))[0],
np.copy(h.values(flow=True))[-1],
)
if has_variances:
underflowv, overflowv = (
np.copy(h.variances(flow=True))[0],
np.copy(h.variances(flow=True))[-1],
)

# Set plottables
if flow in ("none", "hint"):
plottables.append(Plottable(value, edges=final_bins, variances=variance))
elif flow == "show":
_flow_bin_size: float = np.max(
[0.05 * (final_bins[-1] - final_bins[0]), np.mean(np.diff(final_bins))]
)
flow_bins = np.copy(final_bins)
if underflow > 0:
flow_bins = np.r_[flow_bins[0] - _flow_bin_size, flow_bins]
value = np.r_[underflow, value]
if has_variances:
variance = np.r_[underflowv, variance]
if overflow > 0:
flow_bins = np.r_[flow_bins, flow_bins[-1] + _flow_bin_size]
value = np.r_[value, overflow]
if has_variances:
variance = np.r_[variance, overflowv]
plottables.append(Plottable(value, edges=flow_bins, variances=variance))
elif flow == "sum":
if underflow > 0:
value[0] += underflow
if has_variances:
variance[0] += underflowv
if overflow > 0:
value[-1] += overflow
if has_variances:
variance[-1] += overflowv
plottables.append(Plottable(value, edges=final_bins, variances=variance))
else:
plottables.append(Plottable(value, edges=final_bins, variances=variance))

if w2 is not None:
for _w2, _plottable in zip(
w2.reshape(len(plottables), len(final_bins) - 1), plottables
):
_plottable.variances = _w2
_plottable.method = w2method

if w2 is not None and yerr is not None:
msg = "Can only supply errors or w2"
raise ValueError(msg)
plottables, flow_info = get_plottables(
hists,
bins=final_bins,
w2=w2,
w2method=w2method,
yerr=yerr,
stack=stack,
density=density,
binwnorm=binwnorm,
flow=flow,
)
flow_bins, underflow, overflow = flow_info

_labels: list[str | None]
if label is None:
Expand Down Expand Up @@ -311,52 +243,6 @@ def iterable_not_string(arg):
for i in range(len(_chunked_kwargs)):
_chunked_kwargs[i][kwarg] = kwargs[kwarg]

############################
# # yerr calculation
_yerr: np.ndarray | None
if yerr is not None:
# yerr is array
if hasattr(yerr, "__len__"):
_yerr = np.asarray(yerr)
# yerr is a number
elif isinstance(yerr, (int, float)) and not isinstance(yerr, bool):
_yerr = np.ones((len(plottables), len(final_bins) - 1)) * yerr
# yerr is automatic
else:
_yerr = None
else:
_yerr = None

if _yerr is not None:
assert isinstance(_yerr, np.ndarray)
if _yerr.ndim == 3:
# Already correct format
pass
elif _yerr.ndim == 2 and len(plottables) == 1:
# Broadcast ndim 2 to ndim 3
if _yerr.shape[-2] == 2: # [[1,1], [1,1]]
_yerr = _yerr.reshape(len(plottables), 2, _yerr.shape[-1])
elif _yerr.shape[-2] == 1: # [[1,1]]
_yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1])
else:
msg = "yerr format is not understood"
raise ValueError(msg)
elif _yerr.ndim == 2:
# Broadcast yerr (nh, N) to (nh, 2, N)
_yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1])
elif _yerr.ndim == 1:
# Broadcast yerr (1, N) to (nh, 2, N)
_yerr = np.tile(_yerr, 2 * len(plottables)).reshape(
len(plottables), 2, _yerr.shape[-1]
)
else:
msg = "yerr format is not understood"
raise ValueError(msg)

assert _yerr is not None
for yrs, _plottable in zip(_yerr, plottables):
_plottable.fixed_errors(*yrs)

# Sorting
if sort is not None:
if isinstance(sort, str):
Expand All @@ -379,34 +265,6 @@ def iterable_not_string(arg):
_chunked_kwargs = [_chunked_kwargs[ix] for ix in order]
_labels = [_labels[ix] for ix in order]

# ############################
# # Stacking, norming, density
if density is True and binwnorm is not None:
msg = "Can only set density or binwnorm."
raise ValueError(msg)
if density is True:
if stack:
_total = np.sum(
np.array([plottable.values for plottable in plottables]), axis=0
)
for plottable in plottables:
plottable.flat_scale(1.0 / np.sum(np.diff(final_bins) * _total))
else:
for plottable in plottables:
plottable.density = True
elif binwnorm is not None:
for plottable, norm in zip(
plottables, np.broadcast_to(binwnorm, (len(plottables),))
):
plottable.flat_scale(norm)
plottable.binwnorm()

# Stack
if stack and len(plottables) > 1:
from .utils import stack as stack_fun

plottables = stack_fun(*plottables)

##########
# Plotting
return_artists: list[StairsArtists | ErrorBarArtists] = []
Expand Down Expand Up @@ -443,8 +301,7 @@ def iterable_not_string(arg):
if "step" in histtype:
for i in range(len(plottables)):
do_errors = yerr is not False and (
(yerr is not None or w2 is not None)
or (plottables[i].variances is not None)
(yerr is not None or w2 is not None) or plottables[i]._has_variances
)

_kwargs = _chunked_kwargs[i]
Expand Down
Loading
Loading