From 737489fc61c03e05ab4865b0f97491b5751e9f62 Mon Sep 17 00:00:00 2001 From: AlbertDominguez Date: Tue, 12 Nov 2024 15:51:36 +0100 Subject: [PATCH] add subpix aggregation option in cli, improve threshold opt --- spotiflow/cli/predict.py | 13 +++++++++++-- spotiflow/model/spotiflow.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/spotiflow/cli/predict.py b/spotiflow/cli/predict.py index 3d93a92..fbe2e16 100644 --- a/spotiflow/cli/predict.py +++ b/spotiflow/cli/predict.py @@ -41,7 +41,7 @@ def get_args(): required.add_argument( "data_path", type=Path, - help=f"Path to image file or directory of image files. If a directory, will process all images in the directory.", + help="Path to image file or directory of image files. If a directory, will process all images in the directory.", ) required.add_argument( "-pm", @@ -128,6 +128,13 @@ def get_args(): default=True, help="Whether to use the stereographic flow to compute subpixel localization. If None, will deduce from the model configuration. Defaults to True.", ) + parser.add_argument( + "-spr", + "--subpix-radius", + type=int, + default=0, + help="Radius of the flow region to consider around the heatmap peak. Defaults to 0 (no aggregation).", + ) predict.add_argument( "-p", "--peak-mode", @@ -284,6 +291,8 @@ def main(): if args.verbose: log.info(f"Predicting spots in {fname} with {n_tiles=}") + _subpix_arg = False if not args.subpix else args.subpix_radius + spots, details = model.predict( img, prob_thresh=args.probability_threshold, @@ -291,7 +300,7 @@ def main(): min_distance=args.min_distance, exclude_border=args.exclude_border, scale=args.scale, - subpix=args.subpix, + subpix=_subpix_arg, peak_mode=args.peak_mode, normalizer=args.normalizer, verbose=args.verbose, diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index f19cfb3..1bd588a 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -1360,8 +1360,8 @@ def _metric_at_threshold(thr, class_label: int = 0): ] if subpix: val_pred_pts = [ - pts + _subpix[class_label][tuple(pts.astype(int).T)] - for pts, _subpix in zip(val_pred_pts, val_flow_preds) + pts + subpixel_offset(pts, np.squeeze(_subpix), np.squeeze(hmap), radius=0) + for hmap, pts, _subpix in zip(val_hm_preds, val_pred_pts, val_flow_preds) ] if self.config.is_3d and any(s > 1 for s in self.config.grid):