From 654fb48888fed4a829a0255e5402b94b67e62e1e Mon Sep 17 00:00:00 2001 From: Larissa Heinrich Date: Thu, 10 Oct 2024 14:29:37 -0400 Subject: [PATCH] feat: :sparkles: allow explicit mapping of axes for retrieving scale level --- src/cellmap_utils_kit/attribute_handler.py | 29 ++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/cellmap_utils_kit/attribute_handler.py b/src/cellmap_utils_kit/attribute_handler.py index 2b5cbe4..118cc97 100644 --- a/src/cellmap_utils_kit/attribute_handler.py +++ b/src/cellmap_utils_kit/attribute_handler.py @@ -52,14 +52,36 @@ def get_res_dict_from_attrs( return result +def get_axisorder( + attrs: h5py.AttributeManager | zarr.attrs.Attributes | dict, +) -> tuple[str]: + """Retrieve the axis ordering as a tuple of strings with the axes names. + + Args: + attrs (h5py.AttributeManager | zarr.attrs.Attributes | dict): _description_ + + Returns: + tuple[str]: Names of the axes in the order that they're referenced in in the + metadata, e.g. ("z", "y", "x) + + """ + ms_attrs = access_attributes(attrs["multiscales"]) + axis_order = [] + for ax in ms_attrs[0]["axes"]: + axis_order.append(ax["name"]) + return tuple(axis_order) + + def get_scalelevel( - group: h5py.Group | zarr.Group, request_scale: Sequence[float] + group: h5py.Group | zarr.Group, request_scale: Sequence[float] | dict[str, float] ) -> str: """Find the name of the array in a multiscale pyramid that has a specific scale. Args: group (h5py.Group | zarr.Group): multiscale group - request_scale (Sequence[float]): scale of the array you're looking for + request_scale (Sequence[float] | dict[str, float]): scale of the array you're + looking for, can be a sequence of values, assuming the same axis order as in + the attributes or a dictionary from axis name to scale value. Raises: ValueError: If that scale is not in the scale pyramid. @@ -70,6 +92,9 @@ def get_scalelevel( """ scales = get_res_dict_from_attrs(group.attrs) ref_scale = None + if isinstance(request_scale, dict): + axis_order = get_axisorder(group.attrs) + request_scale = [request_scale[ax] for ax in axis_order] for sclvl, scale in scales.items(): if tuple(scale) == tuple(request_scale): ref_scale = sclvl