From 885661dd615f4234725bbf9063d407954b69e78a Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:00:54 -0700 Subject: [PATCH] Refactor NGFF module and migrate to Pydantic v2 (#233) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * unify comment format * fix typing and docstring * create ngff sub-package and refactor display util file * refactor ngff meta file * refactor ngff * export transformation model * fix type hint * bump ome-zarr target in docstring * migrate to pydantic v2 * isort * fix validators * remove union type * fix dependency * update docstring * typing improvements * Update iohub/ngff/models.py Co-authored-by: Jordão Bragantini * fix style * update module docstring to specify their content --------- Co-authored-by: Jordão Bragantini --- iohub/convert.py | 3 +- iohub/ngff/__init__.py | 4 + iohub/{display_utils.py => ngff/display.py} | 20 +- iohub/{ngff_meta.py => ngff/models.py} | 368 ++++++++++---------- iohub/{ngff.py => ngff/nodes.py} | 136 ++++---- iohub/reader.py | 2 +- setup.cfg | 3 +- tests/_deprecated/test_zarrfile.py | 1 - tests/cli/test_cli.py | 5 +- tests/ngff/test_ngff.py | 6 +- tests/pyramid/test_pyramid.py | 8 +- tests/test_display_utils.py | 2 +- 12 files changed, 290 insertions(+), 268 deletions(-) create mode 100644 iohub/ngff/__init__.py rename iohub/{display_utils.py => ngff/display.py} (85%) rename iohub/{ngff_meta.py => ngff/models.py} (54%) rename iohub/{ngff.py => ngff/nodes.py} (94%) diff --git a/iohub/convert.py b/iohub/convert.py index d810fbba..4000bbc3 100644 --- a/iohub/convert.py +++ b/iohub/convert.py @@ -10,7 +10,8 @@ from tqdm.contrib.logging import logging_redirect_tqdm from iohub._version import version as iohub_version -from iohub.ngff import Position, TransformationMeta, open_ome_zarr +from iohub.ngff.models import TransformationMeta +from iohub.ngff.nodes import Position, open_ome_zarr from iohub.reader import MMStack, NDTiffDataset, read_images __all__ = ["TIFFConverter"] diff --git a/iohub/ngff/__init__.py b/iohub/ngff/__init__.py new file mode 100644 index 00000000..901df528 --- /dev/null +++ b/iohub/ngff/__init__.py @@ -0,0 +1,4 @@ +from iohub.ngff.models import TransformationMeta +from iohub.ngff.nodes import Plate, Position, open_ome_zarr + +__all__ = ["open_ome_zarr", "Plate", "Position", "TransformationMeta"] diff --git a/iohub/display_utils.py b/iohub/ngff/display.py similarity index 85% rename from iohub/display_utils.py rename to iohub/ngff/display.py index 578c101c..e399de27 100644 --- a/iohub/display_utils.py +++ b/iohub/ngff/display.py @@ -1,9 +1,9 @@ -""" Utility functions for displaying data """ +"""OME-Zarr display settings (OMERO metadata)""" import numpy as np from PIL.ImageColor import colormap -from iohub.ngff_meta import ChannelMeta, WindowDict +from iohub.ngff.models import ChannelMeta, WindowDict """ Dictionary with key works and most popular fluorescent probes """ CHANNEL_COLORS = { @@ -25,8 +25,10 @@ # emission around 440 - 460 nmm "blue": ["Blue", "DAPI", "BFP", "Hoechst"], "red": ["Red"], - "yellow": ["Yellow", "Cy3"], # Emission around 540-570 nm - "orange": ["Orange", "Cy5", "Y5"], # emission around 650-680 nm + # Emission around 540-570 nm + "yellow": ["Yellow", "Cy3"], + # emission around 650-680 nm + "orange": ["Orange", "Cy5", "Y5"], } @@ -49,12 +51,12 @@ def color_to_hex(color: str) -> str: def channel_display_settings( chan_name: str, - clim: tuple[float, float, float, float] = None, + clim: tuple[float, float, float, float] | None = None, first_chan: bool = False, ): - """This will create a dictionary used for OME-zarr metadata. - Allows custom contrast limits and channel. - names for display. Defaults everything to grayscale. + """This will create a dictionary used for OME-Zarr metadata. + Allows custom contrast limits and channel names for display. + Defaults everything to grayscale. Parameters ---------- @@ -85,7 +87,7 @@ def channel_display_settings( "S3": (-1.0, 1.0, -10.0, -10.0), "Other": (0, U16_FMAX, 0.0, U16_FMAX), } - if not clim: + if clim is None: if chan_name in channel_settings.keys(): clim = channel_settings[chan_name] else: diff --git a/iohub/ngff_meta.py b/iohub/ngff/models.py similarity index 54% rename from iohub/ngff_meta.py rename to iohub/ngff/models.py index c424ca8e..d94fb5dc 100644 --- a/iohub/ngff_meta.py +++ b/iohub/ngff/models.py @@ -1,6 +1,8 @@ +from __future__ import annotations + """ Data model classes with validation for OME-NGFF metadata. -Developed against OME-NGFF v0.4 and ome-zarr v0.6 +Developed against OME-NGFF v0.4 and ome-zarr v0.9 Attributes are 'snake_case' with aliases to match NGFF names in JSON output. See https://ngff.openmicroscopy.org/0.4/index.html#naming-style @@ -8,25 +10,42 @@ """ import re -from typing import Any, ClassVar, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional import pandas as pd -from pydantic import BaseModel, Field, root_validator, validator -from pydantic.color import Color, ColorType +from pydantic import ( + AfterValidator, + BaseModel, + ConfigDict, + Field, + NonNegativeInt, + PositiveInt, + field_validator, + model_validator, +) +from pydantic_extra_types.color import Color, ColorType + +# TODO: remove when drop Python < 3.12 +from typing_extensions import Self, TypedDict def unique_validator( - data: list[Union[BaseModel, TypedDict]], field: Union[str, list[str]] -): + data: list[BaseModel], field: str | list[str] +) -> list[BaseModel]: """Called by validators to ensure the uniqueness of certain fields. Parameters ---------- - data : list[Union[BaseModel, TypedDict]] + data : list[BaseModel] list of pydantic models or typed dictionaries - field : Union[str, list[str]] + field : str | list[str] field(s) of the dataclass that must be unique + Returns + ------- + list[BaseModel] + valid input data + Raises ------ ValueError @@ -34,14 +53,15 @@ def unique_validator( """ fields = [field] if isinstance(field, str) else field if not isinstance(data[0], dict): - data = [d.dict() for d in data] - df = pd.DataFrame(data) + params = [d.model_dump() for d in data] + df = pd.DataFrame(params) for key in fields: if not df[key].is_unique: raise ValueError(f"'{key}' must be unique!") + return data -def alpha_numeric_validator(data: str): +def alpha_numeric_validator(data: str) -> str: """Called by validators to ensure that strings are alpha-numeric. Parameters @@ -49,105 +69,112 @@ def alpha_numeric_validator(data: str): data : str string to check + Returns + ------- + str + valid input data + Raises ------ ValueError raised if the string contains characters other than [a-zA-z0-9] """ - if not (data.isalnum() or data.isnumeric()): + if not data.isalnum(): raise ValueError( f"The path name must be alphanumerical! Got: '{data}'." ) + return data TO_DICT_SETTINGS = dict(exclude_none=True, by_alias=True) class MetaBase(BaseModel): - class Config: - allow_population_by_field_name = True - + model_config = ConfigDict(populate_by_name=True) -class AxisMeta(MetaBase): - """https://ngff.openmicroscopy.org/0.4/index.html#axes-md""" - # MUST +class NamedAxisMeta(MetaBase): name: str - # SHOULD - type: Optional[Literal["space", "time", "channel"]] - # SHOULD - unit: Optional[str] - # store constants as class variables - SPACE_UNITS: ClassVar[set] = { - "angstrom", - "attometer", - "centimeter", - "decimeter", - "exameter", - "femtometer", - "foot", - "gigameter", - "hectometer", - "inch", - "kilometer", - "megameter", - "meter", - "micrometer", - "mile", - "millimeter", - "nanometer", - "parsec", - "petameter", - "picometer", - "terameter", - "yard", - "yoctometer", - "yottameter", - "zeptometer", - "zettameter", - } - TIME_UNITS: ClassVar[set] = { - "attosecond", - "centisecond", - "day", - "decisecond", - "exasecond", - "femtosecond", - "gigasecond", - "hectosecond", - "hour", - "kilosecond", - "megasecond", - "microsecond", - "millisecond", - "minute", - "nanosecond", - "petasecond", - "picosecond", - "second", - "terasecond", - "yoctosecond", - "yottasecond", - "zeptosecond", - "zettasecond", - } - - @validator("unit") - def valid_unit(cls, v, values: dict): - if values.get("type") and v is not None: - if values["type"] == "channel": - raise ValueError( - f"Channel axis must not have a unit! Got unit: {v}." - ) - if values["type"] == "space" and v not in cls.SPACE_UNITS: - raise ValueError( - f"Got invalid space unit: '{v}' not in {cls.SPACE_UNITS}" - ) - if values["type"] == "time" and v not in cls.TIME_UNITS: - raise ValueError( - f"Got invalid time unit: '{v}' not in {cls.TIME_UNITS}" - ) - return v + + +class ChannelAxisMeta(NamedAxisMeta): + type: Literal["channel"] = "channel" + + +class SpaceAxisMeta(NamedAxisMeta): + type: Literal["space"] = "space" + unit: ( + Literal[ + "angstrom", + "attometer", + "centimeter", + "decimeter", + "exameter", + "femtometer", + "foot", + "gigameter", + "hectometer", + "inch", + "kilometer", + "megameter", + "meter", + "micrometer", + "mile", + "millimeter", + "nanometer", + "parsec", + "petameter", + "picometer", + "terameter", + "yard", + "yoctometer", + "yottameter", + "zeptometer", + "zettameter", + ] + | None + ) + + +class TimeAxisMeta(NamedAxisMeta): + type: Literal["time"] = "time" + unit: ( + Literal[ + "attosecond", + "centisecond", + "day", + "decisecond", + "exasecond", + "femtosecond", + "gigasecond", + "hectosecond", + "hour", + "kilosecond", + "megasecond", + "microsecond", + "millisecond", + "minute", + "nanosecond", + "petasecond", + "picosecond", + "second", + "terasecond", + "yoctosecond", + "yottasecond", + "zeptosecond", + "zettasecond", + ] + | None + ) + + +class NonstandardAxisMeta(NamedAxisMeta): + type: str | None + unit: str | None + + +"""https://ngff.openmicroscopy.org/0.4/index.html#axes-md""" +AxisMeta = TimeAxisMeta | ChannelAxisMeta | SpaceAxisMeta | NonstandardAxisMeta class TransformationMeta(MetaBase): @@ -156,22 +183,26 @@ class TransformationMeta(MetaBase): # MUST type: Literal["identity", "translation", "scale"] # MUST? (keyword not found in spec for the fields below) - translation: Optional[list[float]] = None - scale: Optional[list[float]] = None - path: Optional[str] = None - - @root_validator - def no_extra_method(cls, values: dict): - count = sum([bool(v) for _, v in values.items()]) - if values["type"] == "identity" and count > 1: + translation: list[float] | None = None + scale: list[float] | None = None + path: Annotated[str, Field(min_length=1)] | None = None + + @model_validator(mode="after") + def no_extra_method(self) -> Self: + methods = sum( + bool(m is not None) + for m in [self.translation, self.scale, self.path] + ) + if self.type == "identity" and methods > 0: raise ValueError( "Method should not be specified for identity transformation!" ) - elif count > 2: + elif self.translation and self.scale: raise ValueError( - "Only one type of transformation method is allowed." + "'translation' and 'scale' cannot be provided " + f"in the same `{type(self).__name__}`!" ) - return values + return self class DatasetMeta(MetaBase): @@ -181,7 +212,7 @@ class DatasetMeta(MetaBase): path: str # MUST coordinate_transformations: list[TransformationMeta] = Field( - alias="coordinateTransformations" + alias=str("coordinateTransformations") ) @@ -189,28 +220,29 @@ class VersionMeta(MetaBase): """OME-NGFF spec version. Default is the current version (0.4).""" # SHOULD - version: Optional[Literal["0.1", "0.2", "0.3", "0.4"]] + version: Literal["0.1", "0.2", "0.3", "0.4"] = "0.4" class MultiScaleMeta(VersionMeta): """https://ngff.openmicroscopy.org/0.4/index.html#multiscale-md""" # MUST - axes: list[AxisMeta] + axes: list[AxisMeta] = Field(..., discriminator="type") # MUST datasets: list[DatasetMeta] # SHOULD - name: Optional[str] = None + name: str | None = None # MAY - coordinate_transformations: Optional[list[TransformationMeta]] = Field( - alias="coordinateTransformations" + coordinate_transformations: list[TransformationMeta] | None = Field( + alias=str("coordinateTransformations"), default=None ) # SHOULD, describes the downscaling method (e.g. 'gaussian') - type: Optional[str] = None + type: str | None = None # SHOULD, additional information about the downscaling method - metadata: Optional[dict] = None + metadata: dict | None = None - @validator("axes") + @field_validator("axes") + @classmethod def unique_name(cls, v): unique_validator(v, "name") return v @@ -238,8 +270,8 @@ class ChannelMeta(MetaBase): color: ColorType = "FFFFFF" family: str = "linear" inverted: bool = False - label: str = None - window: WindowDict = None + label: str | None = None + window: WindowDict | None = None class RDefsMeta(MetaBase): @@ -257,9 +289,9 @@ class OMEROMeta(VersionMeta): """https://ngff.openmicroscopy.org/0.4/index.html#omero-md""" id: int - name: Optional[str] - channels: Optional[list[ChannelMeta]] - rdefs: Optional[RDefsMeta] + name: str | None = None + channels: list[ChannelMeta] | None = None + rdefs: RDefsMeta | None = None class ImagesMeta(MetaBase): @@ -284,9 +316,10 @@ class LabelColorMeta(MetaBase): # MUST label_value: int = Field(alias="label-value") # MAY - rgba: Optional[ColorType] = None + rgba: ColorType | None = None - @validator("rgba") + @field_validator("rgba") + @classmethod def rgba_color(cls, v): v = Color(v).as_rgb_tuple(alpha=True) return v @@ -302,7 +335,8 @@ class ImageLabelMeta(VersionMeta): # MAY source: dict[str, Any] - @validator("colors", "properties") + @field_validator("colors", "properties") + @classmethod def unique_label_value(cls, v): # MUST unique_validator(v, "label_value") @@ -313,36 +347,30 @@ class AcquisitionMeta(MetaBase): """https://ngff.openmicroscopy.org/0.4/index.html#plate-md""" # MUST - id: int + id: NonNegativeInt # SHOULD - name: Optional[str] = None + name: str | None = None # SHOULD - maximum_field_count: Optional[int] = Field(alias="maximumfieldcount") + maximum_field_count: PositiveInt | None = Field( + alias="maximumfieldcount", default=None + ) # MAY - description: Optional[str] = None + description: str | None = None # MAY - start_time: Optional[int] = Field(alias="starttime") + start_time: NonNegativeInt | None = Field( + alias=str("starttime"), default=None + ) # MAY - end_time: Optional[int] = Field(alias="endtime") - - @validator("id", "maximum_field_count", "start_time", "end_time") - def geq_zero(cls, v): - # MUST - if v < 0: - raise ValueError( - "The integer value must be equal or greater to zero!" - ) - return v + end_time: NonNegativeInt | None = Field(alias=str("endtime"), default=None) - @validator("end_time") - def end_after_start(cls, v: int, values: dict): - # CUSTOM - if st := values.get("start_time"): - if st > v: + @model_validator(mode="after") + def end_after_start(self) -> Self: + if self.start_time is not None and self.end_time is not None: + if self.start_time > self.end_time: raise ValueError( - f"Start timestamp {st} is larger than end timestamp {v}." + "The acquisition end time must be after the start time!" ) - return v + return self class PlateAxisMeta(MetaBase): @@ -350,13 +378,7 @@ class PlateAxisMeta(MetaBase): https://ngff.openmicroscopy.org/0.4/index.html#plate-md""" # MUST - name: str - - @validator("name") - def alpha_numeric(cls, v: str): - # MUST - alpha_numeric_validator(v) - return v + name: Annotated[str, AfterValidator(alpha_numeric_validator)] class WellIndexMeta(MetaBase): @@ -364,10 +386,11 @@ class WellIndexMeta(MetaBase): https://ngff.openmicroscopy.org/0.4/index.html#plate-md""" path: str - row_index: int = Field(alias="rowIndex") - column_index: int = Field(alias="columnIndex") + row_index: NonNegativeInt = Field(alias="rowIndex") + column_index: NonNegativeInt = Field(alias="columnIndex") - @validator("path") + @field_validator("path") + @classmethod def row_slash_column(cls, v: str): # MUST # regex: one line that is exactly two words separated by one '/' @@ -377,22 +400,15 @@ def row_slash_column(cls, v: str): ) return v - @validator("row_index", "column_index") - def geq_zero(cls, v: int): - # MUST - if v < 0: - raise ValueError("Well position indices must not be negative!") - return v - class PlateMeta(VersionMeta): """OME-NGFF high-content screening plate metadata. https://ngff.openmicroscopy.org/0.4/index.html#plate-md""" # SHOULD - name: Optional[str] + name: str | None = None # MAY - acquisitions: Optional[list[AcquisitionMeta]] + acquisitions: list[AcquisitionMeta] | None = None # MUST rows: list[PlateAxisMeta] # MUST @@ -400,48 +416,38 @@ class PlateMeta(VersionMeta): # MUST wells: list[WellIndexMeta] # SHOULD - field_count: Optional[int] + field_count: PositiveInt | None = None - @validator("acquisitions") + @field_validator("acquisitions") + @classmethod def unique_id(cls, v): # MUST unique_validator(v, "id") return v - @validator("rows", "columns") + @field_validator("rows", "columns") + @classmethod def unique_name(cls, v): # MUST unique_validator(v, "name") return v - @validator("wells") + @field_validator("wells") + @classmethod def unique_well(cls, v): # CUSTOM unique_validator(v, "path") return v - @validator("field_count") - def positive(cls, v): - # MUST - if v <= 0: - raise ValueError("Field count must be a positive integer!") - return v - class ImageMeta(MetaBase): """Image metadata field under an HCS well group. https://ngff.openmicroscopy.org/0.4/index.html#well-md""" # MUST if `PlateMeta.acquisitions` contains multiple acquisitions - acquisition: Optional[int] + acquisition: int | None = None # MUST - path: str - - @validator("path") - def alpha_numeric(cls, v): - # MUST - alpha_numeric_validator(v) - return v + path: Annotated[str, AfterValidator(alpha_numeric_validator)] class WellGroupMeta(VersionMeta): diff --git a/iohub/ngff.py b/iohub/ngff/nodes.py similarity index 94% rename from iohub/ngff.py rename to iohub/ngff/nodes.py index d5a307e0..6491205c 100644 --- a/iohub/ngff.py +++ b/iohub/ngff/nodes.py @@ -1,3 +1,7 @@ +""" +Node object and convenience functions for the OME-NGFF (OME-Zarr) Hierarchy. +""" + # TODO: remove this in the future (PEP deferred for 3.11, now 3.12?) from __future__ import annotations @@ -5,7 +9,7 @@ import math import os from copy import deepcopy -from typing import TYPE_CHECKING, Generator, Literal, Sequence, Union +from typing import TYPE_CHECKING, Generator, Literal, Sequence, Type import numpy as np import zarr @@ -14,11 +18,12 @@ from pydantic import ValidationError from zarr.util import normalize_storage_path -from iohub.display_utils import channel_display_settings -from iohub.ngff_meta import ( +from iohub.ngff.display import channel_display_settings +from iohub.ngff.models import ( TO_DICT_SETTINGS, AcquisitionMeta, AxisMeta, + ChannelAxisMeta, DatasetMeta, ImageMeta, ImagesMeta, @@ -27,6 +32,8 @@ PlateAxisMeta, PlateMeta, RDefsMeta, + SpaceAxisMeta, + TimeAxisMeta, TransformationMeta, WellGroupMeta, WellIndexMeta, @@ -56,10 +63,8 @@ def _open_store( ) if version != "0.4": _logger.warning( - "\n".join( - "IOHub is only tested against OME-NGFF v0.4.", - f"Requested version {version} may not work properly.", - ) + "IOHub is only tested against OME-NGFF v0.4. " + f"Requested version {version} may not work properly." ) dimension_separator = None else: @@ -84,22 +89,19 @@ def _scale_integers(values: Sequence[int], factor: int) -> tuple[int, ...]: class NGFFNode: """A node (group level in Zarr) in an NGFF dataset.""" - _MEMBER_TYPE = None + _MEMBER_TYPE: Type[NGFFNode] _DEFAULT_AXES = [ - AxisMeta(name="T", type="time", unit="second"), - AxisMeta(name="C", type="channel"), - *[ - AxisMeta(name=i, type="space", unit="micrometer") - for i in ("Z", "Y", "X") - ], + TimeAxisMeta(name="T", unit="second"), + ChannelAxisMeta(name="C"), + *[SpaceAxisMeta(name=i, unit="micrometer") for i in ("Z", "Y", "X")], ] def __init__( self, group: zarr.Group, parse_meta: bool = True, - channel_names: list[str] = None, - axes: list[AxisMeta] = None, + channel_names: list[str] | None = None, + axes: list[AxisMeta] | None = None, version: Literal["0.1", "0.4"] = "0.4", overwriting_creation: bool = False, ): @@ -240,7 +242,7 @@ def is_leaf(self): """ return not self.group_keys() - def print_tree(self, level: int = None): + def print_tree(self, level: int | None = None): """Print hierarchy of the node to stdout. Parameters @@ -375,8 +377,8 @@ def get_tile( self, row: int, column: int, - pre_dims: tuple[Union[int, slice, None]] = None, - ): + pre_dims: tuple[int | slice, ...] | None = None, + ) -> NDArray: """Get a tile as an up-to-5D in-RAM NumPy array. Parameters @@ -385,7 +387,7 @@ def get_tile( Row index. column : int Column index. - pre_dims : tuple[Union[int, slice, None]], optional + pre_dims : tuple[int | slice, ...], optional Indices or slices for previous dimensions than rows and columns with matching shape, e.g. (t, c, z) for 5D arrays, by default None (select all). @@ -402,8 +404,8 @@ def write_tile( data: ArrayLike, row: int, column: int, - pre_dims: tuple[Union[int, slice, None]] = None, - ): + pre_dims: tuple[int | slice, ...] | None = None, + ) -> None: """Write a tile in the Zarr store. Parameters @@ -414,7 +416,7 @@ def write_tile( Row index. column : int Column index. - pre_dims : tuple[Union[int, slice, None]], optional + pre_dims : tuple[int | slice, ...], optional Indices or slices for previous dimensions than rows and columns with matching shape, e.g. (t, c, z) for 5D arrays, by default None (select all). @@ -426,8 +428,8 @@ def get_tile_slice( self, row: int, column: int, - pre_dims: tuple[Union[int, slice, None]] = None, - ): + pre_dims: tuple[int | slice, ...] | None = None, + ) -> tuple[slice, ...]: """Get the slices for a tile in the underlying array. Parameters @@ -436,14 +438,14 @@ def get_tile_slice( Row index. column : int Column index. - pre_dims : tuple[Union[int, slice, None]], optional + pre_dims : tuple[int | slice, ...], optional Indices or slices for previous dimensions than rows and columns with matching shape, e.g. (t, c, z) for 5D arrays, by default None (select all). Returns ------- - tuple[slice] + tuple[slice, ...] Tuple of slices for all the dimensions of the array. """ self._check_rc(row, column) @@ -464,9 +466,11 @@ def get_tile_slice( f"got type {type(pre_dims)}." ) for i, sel in enumerate(pre_dims): + if isinstance(sel, int): + sel = slice(sel) if sel is not None: pad[i] = sel - return tuple(pad) + (r_slice, c_slice) + return tuple((pad + [r_slice, c_slice])) @staticmethod def _check_rc(row: int, column: int): @@ -511,8 +515,8 @@ def __init__( self, group: zarr.Group, parse_meta: bool = True, - channel_names: list[str] = None, - axes: list[AxisMeta] = None, + channel_names: list[str] | None = None, + axes: list[AxisMeta] | None = None, version: Literal["0.1", "0.4"] = "0.4", overwriting_creation: bool = False, ): @@ -544,7 +548,7 @@ def _parse_meta(self): def dump_meta(self): """Dumps metadata JSON to the `.zattrs` file.""" - self.zattrs.update(**self.metadata.dict(**TO_DICT_SETTINGS)) + self.zattrs.update(**self.metadata.model_dump(**TO_DICT_SETTINGS)) @property def _storage_options(self): @@ -592,13 +596,13 @@ def data(self): f"in the group of: {self.array_keys()}" ) - def __getitem__(self, key: Union[int, str]): + def __getitem__(self, key: int | str) -> ImageArray: """Get an image array member of the position. E.g. Raw-coordinates image, a multi-scale level, or labels Parameters ---------- - key : Union[int, str] + key : int| str Name or path to the image array. Integer key is converted to string (name). @@ -633,8 +637,8 @@ def create_image( self, name: str, data: NDArray, - chunks: tuple[int] = None, - transform: list[TransformationMeta] = None, + chunks: tuple[int] | None = None, + transform: list[TransformationMeta] | None = None, check_shape: bool = True, ): """Create a new image array in the position. @@ -677,8 +681,8 @@ def create_zeros( name: str, shape: tuple[int], dtype: DTypeLike, - chunks: tuple[int] = None, - transform: list[TransformationMeta] = None, + chunks: tuple[int] | None = None, + transform: list[TransformationMeta] | None = None, check_shape: bool = True, ): """Create a new zero-filled image array in the position. @@ -757,8 +761,8 @@ def _check_shape(self, data_shape: tuple[int]): def _create_image_meta( self, name: str, - transform: list[TransformationMeta] = None, - extra_meta: dict = None, + transform: list[TransformationMeta] | None = None, + extra_meta: dict | None = None, ): if not transform: transform = [TransformationMeta(type="identity")] @@ -792,7 +796,7 @@ def _omero_meta( self, id: int, name: str, - clims: list[tuple[float, float, float, float]] = None, + clims: list[tuple[float, float, float, float]] | None = None, ): if not clims: clims = [None] * len(self.channel_names) @@ -859,7 +863,7 @@ def append_channel(self, chan_name: str, resize_arrays: bool = True): f"Cannot infer channel axis for shape {shape}." ) img.resize(shape) - if "omero" in self.metadata.dict().keys(): + if "omero" in self.metadata.model_dump().keys(): self.metadata.omero.channels.append( channel_display_settings(chan_name) ) @@ -974,7 +978,7 @@ def scale(self) -> list[float]: def set_transform( self, - image: Union[str, Literal["*"]], + image: str | Literal["*"], transform: list[TransformationMeta], ): """Set the coordinate transformations metadata @@ -982,7 +986,7 @@ def set_transform( Parameters ---------- - image : Union[str, Literal["*"]] + image : str | Literal["*"] Name of one image array (e.g. "0") to transform, or "*" for the whole FOV transform : list[TransformationMeta] @@ -1019,7 +1023,7 @@ def make_tiles( grid_shape: tuple[int, int], tile_shape: tuple[int], dtype: DTypeLike, - transform: list[TransformationMeta] = None, + transform: list[TransformationMeta] | None = None, chunk_dims: int = 2, ): """Make a tiled image array filled with zeros. @@ -1093,8 +1097,8 @@ def __init__( self, group: zarr.Group, parse_meta: bool = True, - channel_names: list[str] = None, - axes: list[AxisMeta] = None, + channel_names: list[str] | None = None, + axes: list[AxisMeta] | None = None, version: Literal["0.1", "0.4"] = "0.4", overwriting_creation: bool = False, ): @@ -1115,7 +1119,9 @@ def _parse_meta(self): def dump_meta(self): """Dumps metadata JSON to the `.zattrs` file.""" - self.zattrs.update({"well": self.metadata.dict(**TO_DICT_SETTINGS)}) + self.zattrs.update( + {"well": self.metadata.model_dump(**TO_DICT_SETTINGS)} + ) def __getitem__(self, key: str): """Get a position member of the well. @@ -1195,8 +1201,8 @@ def __init__( self, group: zarr.Group, parse_meta: bool = True, - channel_names: list[str] = None, - axes: list[AxisMeta] = None, + channel_names: list[str] | None = None, + axes: list[AxisMeta] | None = None, version: Literal["0.1", "0.4"] = "0.4", overwriting_creation: bool = False, ): @@ -1320,10 +1326,10 @@ def __init__( self, group: zarr.Group, parse_meta: bool = True, - channel_names: list[str] = None, - axes: list[AxisMeta] = None, - name: str = None, - acquisitions: list[AcquisitionMeta] = None, + channel_names: list[str] | None = None, + axes: list[AxisMeta] | None = None, + name: str | None = None, + acquisitions: list[AcquisitionMeta] | None = None, version: Literal["0.1", "0.4"] = "0.4", overwriting_creation: bool = False, ): @@ -1380,12 +1386,14 @@ def dump_meta(self, field_count: bool = False): """ if field_count: self.metadata.field_count = len(list(self.positions())) - self.zattrs.update({"plate": self.metadata.dict(**TO_DICT_SETTINGS)}) + self.zattrs.update( + {"plate": self.metadata.model_dump(**TO_DICT_SETTINGS)} + ) def _auto_idx( self, - name: "str", - index: Union[int, None], + name: str, + index: int | None, axis_name: Literal["row", "column"], ): if index is not None: @@ -1422,8 +1430,8 @@ def create_well( self, row_name: str, col_name: str, - row_index: int = None, - col_index: int = None, + row_index: int | None = None, + col_index: int | None = None, ): """Creates a new well group in the plate. The new well will have empty group metadata, @@ -1528,7 +1536,7 @@ def create_position( ) return well.create_position(pos_name, acquisition=acq_index) - def rows(self): + def rows(self) -> Generator[tuple[str, Row], None, None]: """Returns a generator that iterate over the name and value of all the rows in the plate. @@ -1570,14 +1578,12 @@ def open_ome_zarr( store_path: StrOrBytesPath, layout: Literal["auto", "fov", "hcs", "tiled"] = "auto", mode: Literal["r", "r+", "a", "w", "w-"] = "r", - channel_names: list[str] = None, - axes: list[AxisMeta] = None, + channel_names: list[str] | None = None, + axes: list[AxisMeta] | None = None, version: Literal["0.1", "0.4"] = "0.4", - synchronizer: Union[ - zarr.ThreadSynchronizer, zarr.ProcessSynchronizer - ] = None, + synchronizer: zarr.ThreadSynchronizer | zarr.ProcessSynchronizer = None, **kwargs, -): +) -> Plate | Position | TiledPosition: """Convenience method to open OME-Zarr stores. Parameters diff --git a/iohub/reader.py b/iohub/reader.py index 708fb81e..69cb2460 100644 --- a/iohub/reader.py +++ b/iohub/reader.py @@ -16,7 +16,7 @@ from iohub.fov import BaseFOVMapping from iohub.mmstack import MMStack from iohub.ndtiff import NDTiffDataset -from iohub.ngff import NGFFNode, Plate, Position, open_ome_zarr +from iohub.ngff.nodes import NGFFNode, Plate, Position, open_ome_zarr if TYPE_CHECKING: from _typeshed import StrOrBytesPath diff --git a/setup.cfg b/setup.cfg index 28384572..947878a0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,8 @@ python_requires = >=3.10 setup_requires = setuptools_scm install_requires = pandas>=1.5.2 - pydantic>=1.10.2, <2 + pydantic>=2.8.2 + pydantic_extra_types>=2.9.0 tifffile>=2024.1.30 natsort>=7.1.1 ndtiff>=2.2.1 diff --git a/tests/_deprecated/test_zarrfile.py b/tests/_deprecated/test_zarrfile.py index 2f55f856..22f412d4 100644 --- a/tests/_deprecated/test_zarrfile.py +++ b/tests/_deprecated/test_zarrfile.py @@ -3,7 +3,6 @@ from iohub._deprecated.zarrfile import ZarrReader from iohub.reader import read_images - from tests.conftest import mm2gamma_zarr_v01 diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index b8ddf9c7..4f4f4b34 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -1,17 +1,16 @@ import re from unittest.mock import patch -from click.testing import CliRunner import pytest +from click.testing import CliRunner from iohub._version import __version__ from iohub.cli.cli import cli - from tests.conftest import ( + hcs_ref, mm2gamma_ome_tiffs, ndtiff_v2_datasets, ndtiff_v3_labeled_positions, - hcs_ref, ) diff --git a/tests/ngff/test_ngff.py b/tests/ngff/test_ngff.py index b280b9e7..d1701529 100644 --- a/tests/ngff/test_ngff.py +++ b/tests/ngff/test_ngff.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from _typeshed import StrPath -from iohub.ngff import ( +from iohub.ngff.nodes import ( TO_DICT_SETTINGS, Plate, TransformationMeta, @@ -366,7 +366,7 @@ def test_set_transform_image(ch_shape_dtype, arr_name): ext_reader = Reader(parse_url(dataset.zgroup.store.path)) node = list(ext_reader())[0] assert node.metadata["coordinateTransformations"][0] == [ - translate.dict(**TO_DICT_SETTINGS) for translate in transform + translate.model_dump(**TO_DICT_SETTINGS) for translate in transform ] @@ -399,7 +399,7 @@ def test_set_transform_fov(ch_shape_dtype, arr_name): # read data with plain zarr group = zarr.open(store_path) assert group.attrs["multiscales"][0]["coordinateTransformations"] == [ - translate.dict(**TO_DICT_SETTINGS) for translate in transform + translate.model_dump(**TO_DICT_SETTINGS) for translate in transform ] diff --git a/tests/pyramid/test_pyramid.py b/tests/pyramid/test_pyramid.py index 50b76c76..9464f497 100644 --- a/tests/pyramid/test_pyramid.py +++ b/tests/pyramid/test_pyramid.py @@ -5,8 +5,12 @@ from ome_zarr.io import parse_url from ome_zarr.reader import Multiscales, Reader -from iohub.ngff import Position, _pad_shape, open_ome_zarr -from iohub.ngff_meta import TransformationMeta +from iohub.ngff.nodes import ( + Position, + TransformationMeta, + _pad_shape, + open_ome_zarr, +) def _mock_fov( diff --git a/tests/test_display_utils.py b/tests/test_display_utils.py index 3f1ef55a..8cf1271c 100644 --- a/tests/test_display_utils.py +++ b/tests/test_display_utils.py @@ -1,4 +1,4 @@ -from iohub.display_utils import channel_display_settings +from iohub.ngff.display import channel_display_settings def test_channel_display_settings():