
initial commit with added docs prior to refactoring to simplify the naming and functions.
edyoshikun authored and ieivanov committed Jul 18, 2024
1 parent 94729f0 commit 16d2b83
Showing 1 changed file with 372 additions and 0 deletions.
372 changes: 372 additions & 0 deletions iohub/ngff_utils.py
@@ -0,0 +1,372 @@
import contextlib
import inspect
import io
import itertools
import multiprocessing as mp
from functools import partial
from pathlib import Path
from typing import Tuple

import click
import numpy as np
from numpy.typing import DTypeLike

from iohub.ngff import Position, open_ome_zarr
from iohub.ngff_meta import TransformationMeta


def create_empty_hcs_zarr(
store_path: Path,
position_keys: list[Tuple[str]],
channel_names: list[str],
shape: Tuple[int],
chunks: Tuple[int] = None,
scale: Tuple[float] = (1, 1, 1, 1, 1),
dtype: DTypeLike = np.float32,
max_chunk_size_bytes=500e6,
) -> None:
"""
If the plate does not exist, create an empty zarr plate.
If the plate exists, append positions and channels if they are not
already in the plate.
Parameters
----------
store_path : Path
hcs plate path
position_keys : list[Tuple[str]]
Position keys, will append if not present in the plate.
e.g. [("A", "1", "0"), ("A", "1", "1")]
shape : Tuple[int]
chunks : Tuple[int]
scale : Tuple[float]
channel_names : list[str]
Channel names, will append if not present in metadata.
dtype : DTypeLike
Modifying from recOrder
https://github.com/mehta-lab/recOrder/blob/d31ad910abf84c65ba927e34561f916651cbb3e8/recOrder/cli/utils.py#L12
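
    Examples
    --------
    A minimal usage sketch; the store path, position keys, channel names,
    and shape below are illustrative:

    >>> create_empty_hcs_zarr(
    ...     store_path=Path("plate.zarr"),
    ...     position_keys=[("A", "1", "0")],
    ...     channel_names=["GFP"],
    ...     shape=(1, 1, 4, 256, 256),
    ... )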
"""
MAX_CHUNK_SIZE = max_chunk_size_bytes # in bytes
bytes_per_pixel = np.dtype(dtype).itemsize

    # Limit chunks to max_chunk_size_bytes
    if chunks is None:
        chunk_zyx_shape = list(shape[-3:])
        # While the ZYX chunk is larger than MAX_CHUNK_SIZE,
        # halve the Z extent (stopping at Z == 1)
while (
chunk_zyx_shape[-3] > 1
and np.prod(chunk_zyx_shape) * bytes_per_pixel > MAX_CHUNK_SIZE
):
chunk_zyx_shape[-3] = np.ceil(chunk_zyx_shape[-3] / 2).astype(int)
chunk_zyx_shape = tuple(chunk_zyx_shape)

chunks = 2 * (1,) + chunk_zyx_shape
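
    # For example (illustrative numbers): with the default 500 MB limit,
    # a float32 (100, 2048, 2048) ZYX chunk (~1.7 GB) is halved along Z
    # to 50 (~0.8 GB) and then to 25 (~0.4 GB), which fits.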

# Create plate
output_plate = open_ome_zarr(
str(store_path), layout="hcs", mode="a", channel_names=channel_names
)

# Create positions
for position_key in position_keys:
position_key_string = "/".join(position_key)
# Check if position is already in the store, if not create it
if position_key_string not in output_plate.zgroup:
position = output_plate.create_position(*position_key)
_ = position.create_zeros(
name="0",
shape=shape,
chunks=chunks,
dtype=dtype,
transform=[TransformationMeta(type="scale", scale=scale)],
)
else:
position = output_plate[position_key_string]

# Check if channel_names are already in the store, if not append them
for channel_name in channel_names:
# Read channel names directly from metadata to avoid race conditions
metadata_channel_names = [
channel.label for channel in position.metadata.omero.channels
]
if channel_name not in metadata_channel_names:
position.append_channel(channel_name, resize_arrays=True)


def get_output_paths(
input_paths: list[Path], output_zarr_path: Path
) -> list[Path]:
"""Generates a mirrored output path list given an input list of positions"""
list_output_path = []
for path in input_paths:
# Select the Row/Column/FOV parts of input path
path_strings = Path(path).parts[-3:]
# Append the same Row/Column/FOV to the output zarr path
list_output_path.append(Path(output_zarr_path, *path_strings))
return list_output_path


def apply_transform_to_zyx_and_save_v2(
func,
position: Position,
output_path: Path,
input_channel_indices: list[int],
output_channel_indices: list[int],
t_idx: int,
t_idx_out: int,
c_idx: int = None,
**kwargs,
) -> None:
"""
Load a zyx array from a Position object, apply a transformation to CZYX or ZYX and save the result to file
Parameters
----------
func : CZYX -> CZYX function
The function to be applied to the data
position : Position
The position object to read from
output_path : Path
The path to output OME-Zarr Store
input_channel_indices : list
The channel indices to process.
If empty list, process all channels.
Must match output_channel_indices if not empty
output_channel_indices : list
The channel indices to write to.
If empty list, write to all channels.
Must match input_channel_indices if not empty
t_idx : int
The time index to process
t_idx_out : int
The time index to write to
    c_idx : int
        The channel index to process when input_channel_indices is empty.
        Default is None
kwargs : dict
Additional arguments to pass to the CZYX function
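
    Examples
    --------
    A minimal usage sketch; the paths, indices, and function below are
    illustrative, and the output store is assumed to exist already (e.g.
    created with create_empty_hcs_zarr):

    >>> def scale_intensity(czyx):
    ...     return czyx * 2
    >>> with open_ome_zarr("in.zarr/A/1/0") as position:
    ...     apply_transform_to_zyx_and_save_v2(
    ...         scale_intensity,
    ...         position,
    ...         Path("out.zarr/A/1/0"),
    ...         input_channel_indices=[0],
    ...         output_channel_indices=[0],
    ...         t_idx=0,
    ...         t_idx_out=0,
    ...     )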
"""

# TODO: temporary fix to slumkit issue
if _is_nested(input_channel_indices):
input_channel_indices = [
int(x) for x in input_channel_indices if x.isdigit()
]
if _is_nested(output_channel_indices):
output_channel_indices = [
int(x) for x in output_channel_indices if x.isdigit()
]

# Check if t_idx should be added to the func kwargs
# This is needed when a different processing is needed for each time point, for example during stabilization
all_func_params = inspect.signature(func).parameters.keys()
if "t_idx" in all_func_params:
kwargs["t_idx"] = t_idx

# Process CZYX vs ZYX
if input_channel_indices is not None and len(input_channel_indices) > 0:
click.echo(f"Processing t={t_idx}")

czyx_data = position.data.oindex[t_idx, input_channel_indices]
if not _check_nan_n_zeros(czyx_data):
transformed_czyx = func(czyx_data, **kwargs)
# Write to file
with open_ome_zarr(output_path, mode="r+") as output_dataset:
output_dataset[0].oindex[
t_idx_out, output_channel_indices
] = transformed_czyx
click.echo(f"Finished Writing.. t={t_idx}")
else:
click.echo(f"Skipping t={t_idx} due to all zeros or nans")
else:
click.echo(f"Processing c={c_idx}, t={t_idx}")

czyx_data = position.data.oindex[t_idx, c_idx : c_idx + 1]
        # Skip processing if the data is all zeros or NaNs
if not _check_nan_n_zeros(czyx_data):
# Apply transformation
transformed_czyx = func(czyx_data, **kwargs)

# Write to file
with open_ome_zarr(output_path, mode="r+") as output_dataset:
output_dataset[0][
t_idx_out, c_idx : c_idx + 1
] = transformed_czyx

click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}")
else:
click.echo(
f"Skipping c={c_idx}, t={t_idx} due to all zeros or nans"
)


# TODO: modify how we get the time and channels like recOrder (isinstance(input, list) or isinstance(input, int) or "all")
def process_single_position_v2(
func,
input_data_path: Path,
output_path: Path,
time_indices: list = [0],
time_indices_out: list = [0],
input_channel_idx: list = [],
output_channel_idx: list = [],
num_processes: int = mp.cpu_count(),
**kwargs,
) -> None:
"""
Register a single position with multiprocessing parallelization over T and C
Parameters
----------
func : CZYX -> CZYX function
The function to be applied to the data
input_data_path : Path
The path to input position
output_path : Path
The path to output OME-Zarr Store
time_indices : list
The time indices to process.
Must match time_indices_out if not "all"
time_indices_out : list
The time indices to write to.
Must match time_indices if not "all"
input_channel_idx : list
The channel indices to process.
If empty list, process all channels.
Must match output_channel_idx if not empty
output_channel_idx : list
The channel indices to write to.
If empty list, write to all channels.
Must match input_channel_idx if not empty
    num_processes : int
        Number of simultaneous processes per position
kwargs : dict
Additional arguments to pass to the CZYX function
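
    Examples
    --------
    A minimal usage sketch; the paths and function below are illustrative,
    and the output store is assumed to have been created beforehand (e.g.
    with create_empty_hcs_zarr):

    >>> def scale_intensity(czyx):
    ...     return czyx * 2
    >>> process_single_position_v2(
    ...     scale_intensity,
    ...     input_data_path=Path("in.zarr/A/1/0"),
    ...     output_path=Path("out.zarr"),
    ...     time_indices=[0],
    ...     input_channel_idx=[0],
    ...     output_channel_idx=[0],
    ...     num_processes=1,
    ... )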
"""
# Function to be applied
click.echo(f"Function to be applied: \t{func}")

# Get the reader and writer
click.echo(f"Input data path:\t{input_data_path}")
click.echo(f"Output data path:\t{str(output_path)}")
input_dataset = open_ome_zarr(str(input_data_path))
stdout_buffer = io.StringIO()
with contextlib.redirect_stdout(stdout_buffer):
input_dataset.print_tree()
click.echo(f" Input data tree: {stdout_buffer.getvalue()}")

# Find time indices
if time_indices == "all":
time_indices = range(input_dataset.data.shape[0])
time_indices_out = time_indices
elif isinstance(time_indices, list):
time_indices_out = range(len(time_indices))

# Check for invalid times
time_ubound = input_dataset.data.shape[0] - 1
if np.max(time_indices) > time_ubound:
raise ValueError(
f"time_indices = {time_indices} includes a time index beyond the maximum index of the dataset = {time_ubound}"
)

# Check the arguments for the function
all_func_params = inspect.signature(func).parameters.keys()
# Extract the relevant kwargs for the function 'func'
func_args = {}
non_func_args = {}

for k, v in kwargs.items():
if k in all_func_params:
func_args[k] = v
else:
non_func_args[k] = v

    # Write the settings into the metadata if existing
    if "extra_metadata" in non_func_args:
        with open_ome_zarr(output_path, mode="r+") as output_dataset:
            output_dataset.zattrs["extra_metadata"] = non_func_args[
                "extra_metadata"
            ]

    # Loop through (T, C), applying the function and writing as we go

    if input_channel_idx is None or len(input_channel_idx) == 0:
        # If input_channel_idx is empty, process all channels,
        # parallelizing over both T and C with itertools.product
_, C, _, _, _ = input_dataset.data.shape
iterable = [
(time_idx, time_idx_out, c)
for (time_idx, time_idx_out), c in itertools.product(
zip(time_indices, time_indices_out), range(C)
)
]
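        # For example (illustrative): time_indices=[0, 1] with C=2 yields
        # [(0, 0, 0), (0, 0, 1), (1, 1, 0), (1, 1, 1)]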
partial_apply_transform_to_zyx_and_save = partial(
apply_transform_to_zyx_and_save_v2,
func,
input_dataset,
output_path / Path(*input_data_path.parts[-3:]),
input_channel_idx,
output_channel_idx,
**func_args,
)
    else:
        # If input_channel_idx is given, parallelize over T only
iterable = list(zip(time_indices, time_indices_out))
partial_apply_transform_to_zyx_and_save = partial(
apply_transform_to_zyx_and_save_v2,
func,
input_dataset,
output_path / Path(*input_data_path.parts[-3:]),
input_channel_idx,
output_channel_idx,
c_idx=0,
**func_args,
)

click.echo(f"\nStarting multiprocess pool with {num_processes} processes")
with mp.Pool(num_processes) as p:
p.starmap(
partial_apply_transform_to_zyx_and_save,
iterable,
)


def _is_nested(lst):
"""
Check if the list is nested or not.
    NOTE: this function was created to work around a bug in slumkit that nested input_channel_indices into a list of lists
TODO: check if this is still an issue in slumkit
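
    For example (illustrative): _is_nested([[0, 1]]) and _is_nested(["0"])
    return True, while _is_nested([0, 1]) returns False.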
"""
return any(isinstance(i, list) for i in lst) or any(
isinstance(i, str) for i in lst
)


def _check_nan_n_zeros(input_array):
"""
    Return True if any channel of the input array is entirely zeros or NaNs.
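
    For example (illustrative): a CZYX array in which one channel is all
    zeros returns True; an array in which every channel contains some
    nonzero, non-NaN values returns False.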
"""
    if len(input_array.shape) == 3:
        # ZYX array: check whether all values are zeros or NaNs
        if np.all(input_array == 0) or np.all(np.isnan(input_array)):
            return True
    elif len(input_array.shape) == 4:
        # CZYX array: check each channel independently
        num_channels = input_array.shape[0]
        for c in range(num_channels):
            zyx_array = input_array[c, :, :, :]
            if np.all(zyx_array == 0) or np.all(np.isnan(zyx_array)):
                return True
    else:
        raise ValueError("Input array must be 3D or 4D")

    return False
