Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend fitting function to 3D and add to model predict #20

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions spotiflow/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .. import __version__
from ..model import Spotiflow
from ..utils import estimate_params, infer_n_tiles, str2bool
from ..utils import infer_n_tiles, str2bool
from ..utils.fitting import signal_to_background

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,10 +138,10 @@ def get_args():
help="Peak detection mode (can be either 'skimage' or 'fast', which is a faster custom C++ implementation). Defaults to 'fast'.",
)
predict.add_argument(
"--estimate-fwhm",
"--estimate-params",
type=str2bool,
default=False,
help="Estimate FWHM of detected spots by Gaussian fitting. Defaults to False.",
help="Estimate fit parameters of detected spots by Gaussian fitting (eg FWHM, intensity). Defaults to False.",
)
predict.add_argument(
"-norm",
Expand Down Expand Up @@ -296,24 +296,19 @@ def main():
normalizer=args.normalizer,
verbose=args.verbose,
device=args.device,
fit_params=args.estimate_params,
)
csv_columns = ("y", "x")
if spots.shape[1] == 3:
csv_columns = ("z",) + csv_columns
df = pd.DataFrame(np.round(spots, 4), columns=csv_columns)
df["intensity"] = np.round(details.intens, 2)
df["probability"] = np.round(details.prob, 3)
if args.estimate_fwhm:
if spots.shape[1] == 3:
log.warning(
"Estimating FWHM is not supported for 3D images yet. Skipping."
)
else:
params = estimate_params(img, spots)
df['fwhm'] = np.round(params.fwhm, 3)
df['intens_A'] = np.round(params.intens_A, 3)
df['intens_B'] = np.round(params.intens_B, 3)
df['snb'] = np.round(signal_to_background(params), 3)
if args.estimate_params:
df['fwhm'] = np.round(details.fit_params.fwhm, 3)
df['intens_A'] = np.round(details.fit_params.intens_A, 3)
df['intens_B'] = np.round(details.fit_params.intens_B, 3)
df['snb'] = np.round(signal_to_background(details.fit_params), 3)

df.to_csv(out_dir / f"{fname.stem}.csv", index=False)
return 0
Expand Down
13 changes: 8 additions & 5 deletions spotiflow/lib/spotflow3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,12 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args)

PyArrayObject *points = NULL;
PyArrayObject *dst = NULL;
PyArrayObject *sigmas = NULL;
PyArrayObject *probs = NULL;
int shape_z, shape_y, shape_x;
int grid_z, grid_y, grid_x;
float sigma;

if (!PyArg_ParseTuple(args, "O!iiiiiif", &PyArray_Type, &points, &shape_z, &shape_y, &shape_x, &grid_z, &grid_y, &grid_x, &sigma))
if (!PyArg_ParseTuple(args, "O!O!O!iiiiii", &PyArray_Type, &points, &PyArray_Type, &probs, &PyArray_Type, &sigmas, &shape_z, &shape_y, &shape_x, &grid_z, &grid_y, &grid_x))
return NULL;

npy_intp *dims = PyArray_DIMS(points);
Expand Down Expand Up @@ -204,8 +205,6 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args)
index.buildIndex();


const float sigma_denom = 2 * sigma * sigma / cbrt(grid_z * grid_y * grid_x);

#ifdef __APPLE__
#pragma omp parallel for
#else
Expand Down Expand Up @@ -237,8 +236,12 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args)

const float r2 = x * x + y * y + z * z;

const float prob = *(float *)PyArray_GETPTR1(probs, ret_index);
const float sigma = *(float *)PyArray_GETPTR1(sigmas, ret_index);
const float sigma_denom = 2 * sigma * sigma / cbrt(grid_z * grid_y * grid_x);

// the gaussian value
const float val = exp(-r2 / sigma_denom);
const float val = prob * exp(-r2 / sigma_denom);

*(float *)PyArray_GETPTR3(dst, i, j, k) = val;
}
Expand Down
14 changes: 11 additions & 3 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
prob_to_points,
subpixel_offset,
trilinear_interp_points,
estimate_params
)
from ..utils import (
tile_iterator as parallel_tile_iterator,
Expand Down Expand Up @@ -714,6 +715,7 @@ def predict(
Union[torch.device, Literal["auto", "cpu", "cuda", "mps"]]
] = None,
distributed_params: Optional[dict] = None,
fit_params: bool = False,
) -> Tuple[np.ndarray, SimpleNamespace]:
"""Predict spots in an image.

Expand All @@ -730,7 +732,7 @@ def predict(
verbose (bool, optional): Whether to print logs and progress. Defaults to True.
progress_bar_wrapper (Optional[callable], optional): Progress bar wrapper to use. Defaults to None.
device (Optional[Union[torch.device, Literal["auto", "cpu", "cuda", "mps"]]], optional): computing device to use. If None, will infer from model location. If "auto", will infer from available hardware. Defaults to None.

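fit_params (bool, optional): Whether to estimate Gaussian fit parameters (e.g. FWHM, intensity) of the detected spots in the input image. The result is returned in `details.fit_params`. Defaults to False.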
Returns:
Tuple[np.ndarray, SimpleNamespace]: Tuple of (points, details). Points are the coordinates of the spots. Details is a namespace containing the spot-wise probabilities (`prob`), the heatmap (`heatmap`), the stereographic flow (`flow`), the 2D local offset vector field (`subpix`), the spot intensities (`intens`) and the fitted spot parameters (`fit_params`, which is None unless `fit_params=True` and details are not skipped).
"""
Expand Down Expand Up @@ -1098,6 +1100,11 @@ def predict(
_subpix = None
flow = None

if not skip_details and fit_params:
fit_params = estimate_params(img[...,0], pts)
else:
fit_params = None

if verbose:
log.info(f"Found {len(pts)} spots")

Expand All @@ -1119,8 +1126,9 @@ def predict(
)
intens = img[tuple(pts.round().astype(int).T)]
details = SimpleNamespace(
prob=probs, heatmap=y, subpix=_subpix, flow=flow, intens=intens
)
prob=probs, heatmap=y, subpix=_subpix, flow=flow, intens=intens,
fit_params=fit_params
)
return pts, details

def predict_dataset(
Expand Down
124 changes: 115 additions & 9 deletions spotiflow/utils/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from scipy.optimize import curve_fit

from tqdm.auto import tqdm
from dataclasses import dataclass
AlbertDominguez marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -25,17 +26,31 @@ def _gaussian_2d(yx, y0, x0, sigma, A, B):
y, x = yx
return A * np.exp(-((y - y0) ** 2 + (x - x0) ** 2) / (2 * sigma**2)) + B

def _gaussian_3d(zyx, z0, y0, x0, sigma, A, B):
z, y, x = zyx
return A * np.exp(-((z - z0) ** 2 + (y - y0) ** 2 + (x - x0) ** 2) / (2 * sigma**2)) + B

@dataclass
class SpotParams:
class FitParams2D:
fwhm: Union[float, np.ndarray]
offset_y: Union[float, np.ndarray]
offset_x: Union[float, np.ndarray]
intens_A: Union[float, np.ndarray]
intens_B: Union[float, np.ndarray]
r_squared: Union[float, np.ndarray]

@dataclass
class FitParams3D:
    """Parameters of a 3D Gaussian fit of the form ``A * exp(...) + B``.

    Each field holds a single value for one spot, or an array with one
    entry per spot when the per-spot results are stacked together.
    """
    # Full width at half maximum, derived from the fitted sigma.
    fwhm: Union[float, np.ndarray]
    # Fitted center of the Gaussian along z, relative to the window center.
    offset_z: Union[float, np.ndarray]
    # Fitted center of the Gaussian along y, relative to the window center.
    offset_y: Union[float, np.ndarray]
    # Fitted center of the Gaussian along x, relative to the window center.
    offset_x: Union[float, np.ndarray]
    # Intensity term derived from the fitted amplitude A (rescaled from the normalized fit).
    intens_A: Union[float, np.ndarray]
    # Intensity term derived from the fitted background B (rescaled from the normalized fit).
    intens_B: Union[float, np.ndarray]
    # Coefficient of determination of the fit.
    r_squared: Union[float, np.ndarray]


def signal_to_background(params: SpotParams) -> np.ndarray:
def signal_to_background(params: FitParams2D) -> np.ndarray:
"""Calculates the signal to background ratio of the spots. Given a Gaussian fit
of the form A*exp(...) + B, the signal to background
ratio is computed as A/B.
Expand All @@ -52,13 +67,35 @@ def signal_to_background(params: SpotParams) -> np.ndarray:
return snb


def _r_squared(y_true, y_pred):
y_true, y_pred = np.array(y_true).ravel(), np.array(y_pred).ravel()
ss_res = np.sum((y_true - y_pred)**2)
ss_tot = np.sum((y_true - np.mean(y_true))**2)
r2 = 1 - (ss_res / ss_tot)
return r2

def _estimate_params_single(
center: np.ndarray,
image: np.ndarray,
window: int,
refine_centers: bool,
verbose: bool,
) -> SpotParams:
) -> Union[FitParams2D, FitParams3D]:

if image.ndim == 2:
return _estimate_params_single2(center, image, window, refine_centers, verbose)
elif image.ndim == 3:
return _estimate_params_single3(center, image, window, refine_centers, verbose)
else:
raise ValueError("Image must have 2 or 3 dimensions")

def _estimate_params_single2(
center: np.ndarray,
image: np.ndarray,
window: int,
refine_centers: bool,
verbose: bool,
) -> FitParams2D:
x_range = np.arange(-window, window + 1)
y_range = np.arange(-window, window + 1)
y, x = np.meshgrid(y_range, x_range, indexing="ij")
Expand Down Expand Up @@ -89,20 +126,81 @@ def _estimate_params_single(
p0=initial_guess,
bounds=(lower_bounds, upper_bounds),
)

pred = _gaussian_2d((y.ravel(), x.ravel()), *popt)
r_squared = _r_squared(region.ravel(), pred)

except Exception as _:
if verbose:
log.warning("Gaussian fit failed. Returning NaN")
mi, ma = np.nan, np.nan
popt = np.full(5, np.nan)
r_squared = 0

return SpotParams(
return FitParams2D(
fwhm=FWHM_CONSTANT * popt[2],
offset_y=popt[0],
offset_x=popt[1],
intens_A=(popt[3]+popt[4])*(ma - mi),
intens_B=popt[4] * (ma - mi) + mi,
r_squared=r_squared
)

def _estimate_params_single3(
    center: np.ndarray,
    image: np.ndarray,
    window: int,
    refine_centers: bool,
    verbose: bool,
) -> FitParams3D:
    """Fit a 3D Gaussian plus constant background to one spot.

    A cubic window of side ``2 * window + 1`` is cropped from ``image``
    around ``center``, min-max normalized, and fitted with ``_gaussian_3d``
    using bounded least squares (``curve_fit``).

    Args:
        center: Spot position used directly as slice indices (z, y, x).
            # NOTE(review): presumably integer voxel coordinates - confirm with caller.
        image: 3D image volume to crop the window from.
        window: Half-size of the cubic fitting window.
        refine_centers: If True the Gaussian center may move by up to half a
            pixel per axis; otherwise it is pinned to the window center
            (bounds of +/- 1e-6).
        verbose: If True, log a warning when the fit fails.

    Returns:
        FitParams3D with the fitted values. If anything in the fitting step
        raises, all fitted values are NaN and ``r_squared`` is 0.
    """
    # Coordinate grids relative to the window center, one per axis.
    z,y,x = np.meshgrid(*((np.arange(-window, window + 1),)*3), indexing="ij")

    # Crop around the spot
    # NOTE(review): no explicit handling of windows that extend past the image
    # border is visible here - confirm that callers pad or filter such spots.
    region = image[
        center[0] - window : center[0] + window + 1,
        center[1] - window : center[1] + window + 1,
        center[2] - window : center[2] + window + 1,
    ]

    try:
        # Normalize the crop to [0, 1] so the fixed initial guess and bounds apply.
        mi, ma = np.min(region), np.max(region)
        region = (region - mi) / (ma - mi)
        initial_guess = (0, 0, 0, 1.5, 1, 0) # z0, y0, x0, sigma, A, B

        if refine_centers:
            lower_bounds = (-.5, -.5, -.5, 0.1, 0.5, -0.5) # z0, y0, x0, sigma, A, B
            upper_bounds = (.5, .5, .5, 10, 1.5, 0.5) # z0, y0, x0, sigma, A, B
        else:
            # Tiny interval around zero: the center is effectively fixed.
            lower_bounds = (-1e-6, -1e-6, -1e-6, 0.1, 0.5, -0.5)
            upper_bounds = ( 1e-6, 1e-6, 1e-6, 10, 1.5, 0.5)

        popt, _ = curve_fit(
            _gaussian_3d,
            (z.ravel(), y.ravel(), x.ravel()),
            region.ravel(),
            p0=initial_guess,
            bounds=(lower_bounds, upper_bounds),
        )

        # Goodness of fit on the normalized crop.
        pred = _gaussian_3d((z.ravel(), y.ravel(), x.ravel()), *popt)
        r_squared = _r_squared(region.ravel(), pred)

    except Exception as _:
        # Any failure above (cropping statistics, fitting, scoring) falls back
        # to NaN parameters so one bad spot does not abort the caller.
        if verbose:
            log.warning("Gaussian fit failed. Returning NaN")
        mi, ma = np.nan, np.nan
        popt = np.full(6, np.nan)
        r_squared = 0

    # popt order: z0, y0, x0, sigma, A, B. Intensities are mapped back from
    # the normalized scale using the crop's min/max.
    return FitParams3D(
        fwhm=FWHM_CONSTANT * popt[3],
        offset_z=popt[0],
        offset_y=popt[1],
        offset_x=popt[2],
        intens_A=(popt[4]+popt[5])*(ma - mi),
        intens_B=popt[5] * (ma - mi) + mi,
        r_squared=r_squared
    )

def estimate_params(
img: np.ndarray,
Expand Down Expand Up @@ -158,9 +256,17 @@ def estimate_params(
)
)

keys = SpotParams.__dataclass_fields__.keys()

params = SpotParams(
**dict((k, np.array([getattr(p, k) for p in params])) for k in keys)
)
if img.ndim == 2:
keys = FitParams2D.__dataclass_fields__.keys()
AlbertDominguez marked this conversation as resolved.
Show resolved Hide resolved
params = FitParams2D(
**dict((k, np.array([getattr(p, k) for p in params])) for k in keys)
)
elif img.ndim == 3:
keys = FitParams3D.__dataclass_fields__.keys()
AlbertDominguez marked this conversation as resolved.
Show resolved Hide resolved
params = FitParams3D(
**dict((k, np.array([getattr(p, k) for p in params])) for k in keys)
)
else:
raise ValueError("Image must have 2 or 3 dimensions")

return params
Loading
Loading