Skip to content

Commit

Permalink
fix non-tiled prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Jul 3, 2024
1 parent a27b0c6 commit eb96650
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,8 @@ def predict(
Returns:
Tuple[np.ndarray, SimpleNamespace]: Tuple of (points, details). Points are the coordinates of the spots. Details is a namespace containing the spot-wise probabilities, the heatmap and the 2D flow field.
"""
if self.config.out_channels > 1:
raise NotImplementedError("Predicting with multiple channels is not supported yet.")

skip_details = isinstance(img, da.Array) # Avoid computing details for non-NumPy inputs, which are assumed to be large

Expand Down Expand Up @@ -801,7 +803,7 @@ def predict(

ys = np.empty((self.config.out_channels,)+out_shape, np.float32) # C'HW

pts = []
points = []
probs = []
for cl in range(self.config.out_channels):
ys[cl] = center_crop(y[cl], out_shape)
Expand All @@ -812,11 +814,23 @@ def predict(
mode=peak_mode,
min_distance=min_distance,
)

curr_probs = ys[cl][tuple(curr_pts.astype(int).T)].tolist()
pts.append(curr_pts)
if subpix_radius >= 0:
subpix_tile = flow_to_vector(flow[cl], sigma=self.config.sigma)
_offset = subpixel_offset(curr_pts, subpix_tile, ys[cl], radius=subpix_radius)
curr_pts = curr_pts + _offset


points.append(curr_pts)
probs.append(curr_probs)

# ! FIXME: This is a temporary fix which will stop working when multi-channel output is implemented
points = points[0]
probs = probs[0]
y = ys[0]
if subpix_radius >= 0:
_subpix = _subpix[0]
flow = flow[0]

else: # Predict with tiling
if self.config.out_channels > 1:
Expand Down Expand Up @@ -950,20 +964,19 @@ def predict(

points = np.concatenate(points, axis=0)

# Remove padding
# Remove padding
padding_to_correct = (padding[0][0], padding[1][0])
if self.config.is_3d:
padding_to_correct = (*padding_to_correct, padding[2][0])
points = points - np.array(padding_to_correct)[None]/corr_grid
probs = np.asarray(probs)
# if scale is not None and scale != 1:
# points = np.round((points.astype(float) / scale)).astype(int)
probs = np.asarray(probs)
# if scale is not None and scale != 1:
# points = np.round((points.astype(float) / scale)).astype(int)
probs = filter_shape(probs, out_shape, idxr_array=points)
pts = filter_shape(points, out_shape, idxr_array=points)

probs = filter_shape(probs, out_shape, idxr_array=points)
pts = filter_shape(points, out_shape, idxr_array=points)

if self.config.is_3d and any(s>1 for s in self.config.grid):
pts *= np.asarray(self.config.grid)
if self.config.is_3d and any(s>1 for s in self.config.grid):
pts *= np.asarray(self.config.grid)

if skip_details:
y = None
Expand All @@ -976,7 +989,6 @@ def predict(
if verbose:
log.info(f"Found {len(pts)} spots")


details = SimpleNamespace(prob=probs, heatmap=y, subpix=_subpix, flow=flow)
return pts, details

Expand Down

0 comments on commit eb96650

Please sign in to comment.