diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py index e2889d01..d60a5c3d 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py @@ -9,6 +9,7 @@ from xarray_multiscale import windowed_mean import numpy as np import dask.array as da +from skimage.transform import rescale from typing import Sequence @@ -62,7 +63,10 @@ def preprocess(self, array: Array) -> Array: Preprocess an array by resampling it to the desired voxel size. """ if self.downsample is not None: - downsample = Coordinate(self.downsample) + downsample = list(self.downsample) + for i, axis_name in enumerate(array.axis_names): + if "^" in axis_name: + downsample = downsample[:i] + [1] + downsample[i:] return Array( data=downscale_dask( adjust_shape(array.data, downsample), @@ -70,12 +74,41 @@ def preprocess(self, array: Array) -> Array: scale_factors=downsample, ), offset=array.offset, - voxel_size=array.voxel_size * downsample, + voxel_size=array.voxel_size * self.downsample, axis_names=array.axis_names, units=array.units, ) elif self.upsample is not None: - raise NotImplementedError("Upsampling not yet implemented") + upsample = list(self.upsample) + for i, axis_name in enumerate(array.axis_names): + if "^" in axis_name: + upsample = upsample[:i] + [1] + upsample[i:] + + depth = [int(x > 1) for x in upsample] + trim_slicing = tuple( + slice(d * s, (-d * s)) if d > 1 else slice(None) + for d, s in zip(depth, upsample) + ) + + rescaled_arr = da.map_overlap( + lambda x: rescale( + x, upsample, order=int(self.interp_order), preserve_range=True + )[trim_slicing], + array.data, + depth=depth, + boundary="reflect", + trim=False, + dtype=array.data.dtype, + chunks=tuple(c * u for c, u in zip(array.data.chunksize, upsample)), + ) + + return Array( + data=rescaled_arr, + offset=array.offset, + voxel_size=array.voxel_size / self.upsample, + axis_names=array.axis_names, + units=array.units, + ) def array(self, mode: str = "r") -> Array: source_array = self.source_array_config.array(mode)