diff --git a/README.md b/README.md index b350440..76ffa35 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ # Spotiflow - accurate and efficient spot detection with stereographic flow -*Spotiflow* is a deep learning-based, threshold-agnostic, and subpixel-accurate spot detection method for fluorescence microscopy. It is primarily developed for spatial transcriptomics workflows that require transcript detection in large, multiplexed FISH-images, although it can also be used to detect spot-like structures in general fluorescence microscopy images. A more detailed description of the method can be found in [our paper](https://doi.org/10.1101/2024.02.01.578426). +*Spotiflow* is a deep learning-based, threshold-agnostic, subpixel-accurate 2D and 3D spot detection method for fluorescence microscopy. It is primarily developed for spatial transcriptomics workflows that require transcript detection in large, multiplexed FISH-images, although it can also be used to detect spot-like structures in general fluorescence microscopy images and volumes. A more detailed description of the method can be found in [our paper](https://doi.org/10.1101/2024.02.01.578426). ![Overview](artwork/overview.png) @@ -46,9 +46,12 @@ pip install spotiflow ## Usage -### Training +### Training (2D images) See the [example training script](scripts/train.py) or the [example notebook](examples/1_train.ipynb) for an example of training. For finetuning an already pretrained model, please refer to the [finetuning example notebook](examples/3_finetune.ipynb). +### Training (3D volumes) +See the [example 3D training script](scripts/train_simple_3d.py). For more information, please refer to the [3D training example notebook](examples/4_train_3d.ipynb). Fine-tuning a 3D model can be done by following the same workflow as to the 2D case. + ### Inference (CLI) You can use the CLI to run inference on an image or folder containing several images. To do that, you can use the following command: @@ -61,7 +64,7 @@ where PATH can be either an image or a folder. By default, the command will use ### Inference (API) -The API allows detecting spots in a new image in a few lines of code! Please check the [corresponding example notebook](examples/2_inference.ipynb) and the documentation for a more in-depth explanation. +The API allows detecting spots in a new image in a few lines of code! Please check the [corresponding example notebook](examples/2_inference.ipynb) and the documentation for a more in-depth explanation. The same procedure can be followed for 3D volumes. ```python from spotiflow.model import Spotiflow @@ -82,7 +85,7 @@ points, details = model.predict(img) # points contains the coordinates of the de ``` ### Napari plugin -Our napari plugin allows detecting spots directly with an easy-to-use UI. See [napari-spotiflow](https://github.com/weigertlab/napari-spotiflow) for more information. +Our napari plugin allows detecting spots in 2D and 3D directly with an easy-to-use UI. See [napari-spotiflow](https://github.com/weigertlab/napari-spotiflow) for more information. ## For developers @@ -128,13 +131,13 @@ If you use this code in your research, please cite [the Spotiflow paper](https:/ ```bibtex @article {dominguezmantes24, - author = {Albert Dominguez Mantes and Antonio Herrera and Irina Khven and Anjalie Schlaeppi and Eftychia Kyriacou and Georgios Tsissios and Can Aztekin and Joachim Ligner and Gioele La Manno and Martin Weigert}, - title = {Spotiflow: accurate and efficient spot detection for imaging-based spatial transcriptomics with stereographic flow regression}, + author = {Dominguez Mantes, Albert and Herrera, Antonio and Khven, Irina and Schlaeppi, Anjalie and Kyriacou, Eftychia and Tsissios, Georgios and Skoufa, Evangelia and Santangeli, Luca and Buglakova, Elena and Durmus, Emine Berna and Manley, Suliana and Kreshuk, Anna and Arendt, Detlev and Aztekin, Can and Lingner, Joachim and La Manno, Gioele and Weigert, Martin}, + title = {Spotiflow: accurate and efficient spot detection for fluorescence microscopy with deep stereographic flow regression}, elocation-id = {2024.02.01.578426}, year = {2024}, doi = {10.1101/2024.02.01.578426}, publisher = {Cold Spring Harbor Laboratory}, - URL = {https://www.biorxiv.org/content/early/2024/02/05/2024.02.01.578426}, + URL = {https://www.biorxiv.org/content/early/2024/02/05/2024.02.01.578426}, eprint = {https://www.biorxiv.org/content/early/2024/02/05/2024.02.01.578426.full.pdf}, journal = {bioRxiv} } diff --git a/examples/4_train_3d.ipynb b/examples/4_train_3d.ipynb new file mode 100644 index 0000000..fd66829 --- /dev/null +++ b/examples/4_train_3d.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6dca17b2-9f03-4519-9e5a-19241ce76b73", + "metadata": { + "tags": [] + }, + "source": [ + "# Spotiflow: training your own model in 3D" + ] + }, + { + "cell_type": "markdown", + "id": "4b3ebdc4-b6a2-4a0e-b553-06ecf037cc37", + "metadata": {}, + "source": [ + "**NOTE**: this notebook requires `napari` to be installed if you want to visualize the data (optional but recommended). You can install it e.g. via `pip install napari[all]` (see [the instructions](https://napari.org/stable/tutorials/fundamentals/installation.html) if you have any issue).\n", + "\n", + "Let's first load all the libraries we're gonna need to detect spots in our volumes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2138a5b-4ed2-4c00-ab5b-0d289ffa41e1", + "metadata": {}, + "outputs": [], + "source": [ + "from spotiflow.model import Spotiflow, SpotiflowModelConfig\n", + "from spotiflow.sample_data import load_dataset\n", + "from spotiflow.utils import get_data\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "id": "98b51535-bf2f-4ebd-bcbe-40d561c9792e", + "metadata": {}, + "source": [ + "Similarly to the 2D case, we first load our dataset. We will use the `synth_3d` dataset (corresponding to `synthetic-3d` in the paper, which is a good starting point if you want to then fine-tune on your own data). If you have your own annotated data, you can load it and store it in six different variables corresponding to the training images and spots, to the validation images and spots and to the test images and spots. You can use the `load_data()` function to that end (please [see the docs](https://weigertlab.github.io/spotiflow) to check the data format that the function allows)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "579a5acb-1ebc-436b-a794-23ce6e9f94d1", + "metadata": {}, + "outputs": [], + "source": [ + "trX, trY, valX, valY, testX, testY = load_dataset(\"synth_3d\", include_test=True)\n", + "# trX, trY, valX, valY, testX, testY = get_data(\"/FOLDER/WITH/DATA\", include_test=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49bd5718", + "metadata": {}, + "outputs": [], + "source": [ + "trY[0].shape" + ] + }, + { + "cell_type": "markdown", + "id": "74a14129-713d-432d-9e32-59dc1463d86d", + "metadata": {}, + "source": [ + "The first two variables should contain the training images and annotations, while the latter the validation ones. While visualizing the images in Python is quite straightforward, that is not the case for 3D volumes. We will use the `napari` library to visualize the volumes. If you don't have it installed, you can do so by checking the first cell in the notebook. The cell below won't run if you don't have `napari` installed, but you can still run the rest of the notebook without it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca5e446c-fea5-44df-a6cb-3058e0ce7985", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import napari\n", + " viewer = napari.Viewer(ndisplay=3)\n", + " viewer.add_image(trX[0], name=\"Training volume\", colormap=\"gray\")\n", + " viewer.add_points(trY[0], name=\"Training spots\", face_color=\"orange\", edge_color=\"orange\", size=5, symbol=\"ring\")\n", + "except ImportError as _:\n", + " print(\"napari not installed, skipping visualization\")\n", + " viewer = None\n", + "except Exception as e:\n", + " raise e" + ] + }, + { + "cell_type": "markdown", + "id": "6b385ee0-323d-4efa-be41-9b9fe423017c", + "metadata": {}, + "source": [ + "Training with the default model configuration is straightforward, althought not as much as in the 2D case. First we need to instantiate the model configuration (check [the documentation](https://weigertlab.github.io/spotiflow) for more information about other options):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9e41080-6b9f-44b8-aa7f-b8f533ccb60b", + "metadata": {}, + "outputs": [], + "source": [ + "config = SpotiflowModelConfig(\n", + " is_3d=True, # 3D model\n", + " grid=(2, 2, 2), # predict on a downsampled grid, this is the value used in the paper\n", + ")\n", + "model = Spotiflow(config=config)" + ] + }, + { + "cell_type": "markdown", + "id": "580f0445-4cb6-49da-9e18-512e7e501281", + "metadata": {}, + "source": [ + "We can now train the model with calling `.fit()` after setting where we want the model to be stored. Again, you need to define the training parameters. If you want to change some values (_e.g._ the number of epochs), simply change the parameter accordingly (for more information, check [the documentation](https://weigertlab.github.io/spotiflow)):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f37e8f49-cc73-4761-a2dc-17f181ac4d66", + "metadata": {}, + "outputs": [], + "source": [ + "save_folder = \"models/synth_3d\" # change to where you want to store the model\n", + "train_config = {\n", + " \"num_epochs\": 5,\n", + " \"crop_size\": 128,\n", + " \"crop_size_depth\": 32,\n", + " \"smart_crop\": True,\n", + "}\n", + "model.fit(\n", + " trX,\n", + " trY,\n", + " valX,\n", + " valY,\n", + " save_dir=save_folder,\n", + " train_config=train_config,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f0000786-69b0-4332-8346-c73b86bf3415", + "metadata": {}, + "source": [ + "Our model is now ready to be used! Let's first check the save folder to make sure the model was stored properly (there should be two `.pt` files (`best.pt` and `last.pt`) as well as three `.yaml` configuration files.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79da07cb-100c-4ff6-9588-a83e9b6286f2", + "metadata": {}, + "outputs": [], + "source": [ + "!ls $save_folder" + ] + }, + { + "cell_type": "markdown", + "id": "55ca3e53-89d6-459b-bf97-af871b34fcab", + "metadata": {}, + "source": [ + "We can also quickly predict on a test image which was not seen during training (see [the inference notebook](./2_inference.ipynb) for more information about predicting as well as model loading): " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1972096-aa37-40a1-b90b-900c61202f0c", + "metadata": {}, + "outputs": [], + "source": [ + "test_pred, _ = model.predict(testX[0], device=\"auto\")" + ] + }, + { + "cell_type": "markdown", + "id": "dd4a6108", + "metadata": {}, + "source": [ + "Let's visualize the results now using `napari` (if it is already running):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68c2682e", + "metadata": {}, + "outputs": [], + "source": [ + "if viewer is not None:\n", + " for layer in viewer.layers:\n", + " viewer.layers.remove(layer)\n", + " viewer.add_image(testX[0], name=\"Test volume\", colormap=\"gray\")\n", + " viewer.add_points(testY[0], name=\"Predicted test spots\", face_color=\"orange\", edge_color=\"orange\", size=5, symbol=\"ring\")" + ] + }, + { + "cell_type": "markdown", + "id": "60ce51f9-ef83-4b2e-9a70-57f4a9430487", + "metadata": {}, + "source": [ + "This notebook shows the most user-friendly way to train models. If you want to dive deeper into the model architecture and tweak the code and you are already comfortable with training DL models, please check [the documentation](https://weigertlab.github.io/spotiflow) to get started." + ] + }, + { + "cell_type": "markdown", + "id": "a6d34ee6", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cbaidt", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/spotiflow/model/config.py b/spotiflow/model/config.py index 4803463..83e7c08 100644 --- a/spotiflow/model/config.py +++ b/spotiflow/model/config.py @@ -1,14 +1,15 @@ import abc import argparse -import logging import json -import yaml +import logging import sys - from numbers import Number from pathlib import Path from typing import Literal, Optional, Tuple, Union +import numpy as np +import yaml + logging.basicConfig(level=logging.INFO, stream=sys.stdout) log = logging.getLogger(__name__) @@ -144,6 +145,8 @@ def __init__( self.is_3d = bool(is_3d) if isinstance(grid, int): self.grid = (grid,)*(3 if self.is_3d else 2) + elif not isinstance(grid, tuple): + self.grid = tuple(grid) else: self.grid = grid @@ -163,8 +166,8 @@ def is_valid(self): isinstance(self.in_channels, int) and self.in_channels > 0 ), "in_channels must be greater than 0" assert ( - isinstance(self.out_channels, int) and self.out_channels > 0 - ), "out_channels must be greater than 0" + isinstance(self.out_channels, int) and self.out_channels == 1 + ), "out_channels must be equal to 1 (multi-channel output not supported yet)" assert ( isinstance(self.initial_fmaps, int) and self.initial_fmaps > 0 ), "initial_fmaps must be greater than 0" @@ -212,6 +215,7 @@ def is_valid(self): assert ( all(isinstance(s, int) and s > 0 and (s == 1 or s % 2 == 0) for s in self.grid) ), "grid must be a tuple containing only 1 or even positive integers" + assert len(np.unique(self.grid)) == 1, "grid must currently be isotropic (all dimensions must be equal)" class SpotiflowTrainingConfig(SpotiflowConfig): diff --git a/spotiflow/model/pretrained.py b/spotiflow/model/pretrained.py index 9fc6a16..2fd9463 100644 --- a/spotiflow/model/pretrained.py +++ b/spotiflow/model/pretrained.py @@ -15,7 +15,7 @@ class RegisteredModel: url: str md5_hash: str - + is_3d: bool def list_registered(): return list(_REGISTERED.keys()) @@ -50,13 +50,21 @@ def get_pretrained_model_path(name: str): "hybiss": RegisteredModel( url="https://drive.switch.ch/index.php/s/O4hqFSSGX6veLwa/download", md5_hash="254afa97c137d0bd74fd9c1827f0e323", + is_3d=False, ), "general": RegisteredModel( url="https://drive.switch.ch/index.php/s/6AoTEgpIAeQMRvX/download", - md5_hash="9dd31a36b737204e91b040515e3d899e" + md5_hash="9dd31a36b737204e91b040515e3d899e", + is_3d=False, ), "synth_complex": RegisteredModel( url="https://drive.switch.ch/index.php/s/CiCjNJaJzpVVD2M/download", - md5_hash="d692fa21da47e4a50b4c52f49442508b" + md5_hash="d692fa21da47e4a50b4c52f49442508b", + is_3d=False, ), + "smfish_3d": RegisteredModel( + url="https://drive.switch.ch/index.php/s/Vym7tqiORZOP5Zt/download", + md5_hash="c5ab30ba3b9ccb07b4c34442d1b5b615", + is_3d=True, + ) } diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index 803a8f6..1df460b 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -824,6 +824,7 @@ def predict( points.append(curr_pts) probs.append(curr_probs) + assert self.config.out_channels == 1, "Trying to predict using a multi-channel network, which is not supported yet." # ! FIXME: This is a temporary fix which will stop working when multi-channel output is implemented points = points[0] probs = probs[0] @@ -833,8 +834,6 @@ def predict( flow = flow[0] else: # Predict with tiling - if self.config.out_channels > 1: - raise NotImplementedError("Tiled prediction not implemented for multi-channel output yet.") padded_shape = tuple(np.array(x.shape[:actual_n_dims])//corr_grid) if not skip_details: y = np.empty(padded_shape, np.float32) diff --git a/spotiflow/sample_data/datasets.py b/spotiflow/sample_data/datasets.py index f4ffc73..9b62f86 100644 --- a/spotiflow/sample_data/datasets.py +++ b/spotiflow/sample_data/datasets.py @@ -17,6 +17,7 @@ class RegisteredDataset: url: str md5_hash: str + is_3d: bool def list_registered(): @@ -56,17 +57,27 @@ def load_dataset(name: str, include_test: bool=False): name (str): the name of the dataset to load. include_test (bool, optional): whether to include the test set in the returned data. Defaults to False. """ + if name not in _REGISTERED: + raise NotRegisteredError(f"No training dataset named {name} found. Available datasets: {','.join(sorted(list_registered()))}") + dataset = _REGISTERED[name] path = get_training_datasets_path(name) - return get_data(path, include_test=include_test) + return get_data(path, include_test=include_test, is_3d=dataset.is_3d) _REGISTERED = { "synth_complex": RegisteredDataset( url="https://drive.switch.ch/index.php/s/aWdxUHULLkLLtqS/download", md5_hash="5f44b03603fe1733ac0f2340a69ae238", + is_3d=False, ), "merfish": RegisteredDataset( url="https://drive.switch.ch/index.php/s/fsjOypn4ICpSF2w/download", md5_hash="17fcdbd12cc71630e4f49652ded837c7", + is_3d=False, + ), + "synth_3d": RegisteredDataset( + url="https://drive.switch.ch/index.php/s/EemgJK1Bno8c3n4/download", + md5_hash="f1715515763288362ee3351caca02825", + is_3d=True, ), } diff --git a/tests/test_peaks.py b/tests/test_peaks.py index cbf4552..e26c446 100644 --- a/tests/test_peaks.py +++ b/tests/test_peaks.py @@ -1,49 +1,56 @@ -import numpy as np from types import SimpleNamespace -from spotiflow.utils import points_to_prob, points_to_flow, flow_to_vector, local_peaks, points_from_heatmap_flow +from typing import Tuple + +import numpy as np +import pytest +from spotiflow.utils import points_from_heatmap_flow, points_to_flow, points_to_prob from spotiflow.utils.matching import points_matching -def round_trip(points:np.ndarray, grid:list[int], sigma=1.5): + +def round_trip(points:np.ndarray, grid: Tuple[int], sigma: float=1.5): """ Test round trip of points through the flow field and back""" points = np.asarray(points) ndim = points.shape[1] shape = (2*int(points.max()),)*ndim # get heatmap and flow - heatmap=points_to_prob(points, shape=shape, sigma=sigma, mode="max", grid=grid) + heatmap = points_to_prob(points, shape=shape, sigma=sigma, mode="max", grid=grid) flow = points_to_flow(points, shape, sigma=sigma, grid=grid) points_new = points_from_heatmap_flow(heatmap, flow, sigma=sigma, grid=grid) return SimpleNamespace(points=points, points_new=points_new, heatmap=heatmap, flow=flow, sigma=sigma) - +@pytest.mark.parametrize("ndim", (2, 3)) +@pytest.mark.parametrize("grid", (None, (2, 2))) def test_prob_flow_roundtrip(ndim, grid, debug:bool=False): points = np.stack(np.meshgrid(*tuple(np.linspace(10,48,4) for _ in range(ndim)), indexing="ij"), axis=-1).reshape(-1, ndim) - points = points + np.random.uniform(-1, 1, points.shape) - out = round_trip(points, grid=grid) - - diff = points_matching(out.points, out.points_new).mean_dist + if ndim == 3 and grid is not None: + grid = (*grid, 2) - print(f"Max diff: {diff:4f}") - if debug: - import napari - v = napari.Viewer() - v.add_points(out.points, name="points", size=5, face_color="green") - v.add_points(out.points_new, name="points_new", size=5, face_color="red") - else: - assert diff < 1e-3, f"Max diff: {diff:4f}" + points = points + np.random.uniform(-1, 1, points.shape) + if ndim == 2 and (grid is not None or (isinstance(grid, tuple) and any(g > 1 for g in grid))): + with pytest.raises(NotImplementedError): + _ = round_trip(points, grid=grid) + else: + out = round_trip(points, grid=grid) + + diff = points_matching(out.points, out.points_new).mean_dist + + print(f"Max diff: {diff:4f}") + if debug: + import napari + v = napari.Viewer() + v.add_points(out.points, name="points", size=5, face_color="green") + v.add_points(out.points_new, name="points_new", size=5, face_color="red") + else: + assert diff < 1e-3, f"Max diff: {diff:4f}" - return out if __name__ == "__main__": # works - out = test_prob_flow_roundtrip(ndim=3, grid=None) + test_prob_flow_roundtrip(ndim=3, grid=None) # works - out = test_prob_flow_roundtrip(ndim=3, grid=(2,2,2)) - - # doesnt work - out = test_prob_flow_roundtrip(ndim=3, grid=(1,2,2)) - + test_prob_flow_roundtrip(ndim=3, grid=(2,2,2))