Skip to content

Commit

Permalink
add subpix aggregation option in cli, improve threshold opt
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Nov 12, 2024
1 parent 6e472e3 commit 737489f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 11 additions & 2 deletions spotiflow/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_args():
required.add_argument(
"data_path",
type=Path,
help=f"Path to image file or directory of image files. If a directory, will process all images in the directory.",
help="Path to image file or directory of image files. If a directory, will process all images in the directory.",
)
required.add_argument(
"-pm",
Expand Down Expand Up @@ -128,6 +128,13 @@ def get_args():
default=True,
help="Whether to use the stereographic flow to compute subpixel localization. If None, will deduce from the model configuration. Defaults to True.",
)
parser.add_argument(
"-spr",
"--subpix-radius",
type=int,
default=0,
help="Radius of the flow region to consider around the heatmap peak. Defaults to 0 (no aggregation).",
)
predict.add_argument(
"-p",
"--peak-mode",
Expand Down Expand Up @@ -284,14 +291,16 @@ def main():
if args.verbose:
log.info(f"Predicting spots in {fname} with {n_tiles=}")

_subpix_arg = False if not args.subpix else args.subpix_radius

spots, details = model.predict(
img,
prob_thresh=args.probability_threshold,
n_tiles=n_tiles,
min_distance=args.min_distance,
exclude_border=args.exclude_border,
scale=args.scale,
subpix=args.subpix,
subpix=_subpix_arg,
peak_mode=args.peak_mode,
normalizer=args.normalizer,
verbose=args.verbose,
Expand Down
4 changes: 2 additions & 2 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,8 +1360,8 @@ def _metric_at_threshold(thr, class_label: int = 0):
]
if subpix:
val_pred_pts = [
pts + _subpix[class_label][tuple(pts.astype(int).T)]
for pts, _subpix in zip(val_pred_pts, val_flow_preds)
pts + subpixel_offset(pts, np.squeeze(_subpix), np.squeeze(hmap), radius=0)
for hmap, pts, _subpix in zip(val_hm_preds, val_pred_pts, val_flow_preds)
]

if self.config.is_3d and any(s > 1 for s in self.config.grid):
Expand Down

0 comments on commit 737489f

Please sign in to comment.