Skip to content

Commit

Permalink
add upsample support
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Dec 10, 2024
1 parent b98cece commit e7ca981
Showing 1 changed file with 36 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from xarray_multiscale import windowed_mean
import numpy as np
import dask.array as da
from skimage.transform import rescale

from typing import Sequence

Expand Down Expand Up @@ -62,20 +63,52 @@ def preprocess(self, array: Array) -> Array:
Preprocess an array by resampling it to the desired voxel size.
"""
if self.downsample is not None:
downsample = Coordinate(self.downsample)
downsample = list(self.downsample)
for i, axis_name in enumerate(array.axis_names):
if "^" in axis_name:
downsample = downsample[:i] + [1] + downsample[i:]
return Array(
data=downscale_dask(
adjust_shape(array.data, downsample),
windowed_mean,
scale_factors=downsample,
),
offset=array.offset,
voxel_size=array.voxel_size * downsample,
voxel_size=array.voxel_size * self.downsample,
axis_names=array.axis_names,
units=array.units,
)
elif self.upsample is not None:
raise NotImplementedError("Upsampling not yet implemented")
upsample = list(self.upsample)
for i, axis_name in enumerate(array.axis_names):
if "^" in axis_name:
upsample = upsample[:i] + [1] + upsample[i:]

depth = [int(x > 1) for x in upsample]
trim_slicing = tuple(
slice(d * s, (-d * s)) if d > 1 else slice(None)
for d, s in zip(depth, upsample)
)

rescaled_arr = da.map_overlap(
lambda x: rescale(
x, upsample, order=int(self.interp_order), preserve_range=True
)[trim_slicing],
array.data,
depth=depth,
boundary="reflect",
trim=False,
dtype=array.data.dtype,
chunks=tuple(c * u for c, u in zip(array.data.chunksize, upsample)),
)

return Array(
data=rescaled_arr,
offset=array.offset,
voxel_size=array.voxel_size / self.upsample,
axis_names=array.axis_names,
units=array.units,
)

def array(self, mode: str = "r") -> Array:
source_array = self.source_array_config.array(mode)
Expand Down

0 comments on commit e7ca981

Please sign in to comment.