Skip to content

Commit

Permalink
prepare for release
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Aug 16, 2024
1 parent eb96650 commit 85d6319
Show file tree
Hide file tree
Showing 7 changed files with 316 additions and 42 deletions.
17 changes: 10 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# Spotiflow - accurate and efficient spot detection with stereographic flow

*Spotiflow* is a deep learning-based, threshold-agnostic, and subpixel-accurate spot detection method for fluorescence microscopy. It is primarily developed for spatial transcriptomics workflows that require transcript detection in large, multiplexed FISH-images, although it can also be used to detect spot-like structures in general fluorescence microscopy images. A more detailed description of the method can be found in [our paper](https://doi.org/10.1101/2024.02.01.578426).
*Spotiflow* is a deep learning-based, threshold-agnostic, subpixel-accurate 2D and 3D spot detection method for fluorescence microscopy. It is primarily developed for spatial transcriptomics workflows that require transcript detection in large, multiplexed FISH-images, although it can also be used to detect spot-like structures in general fluorescence microscopy images and volumes. A more detailed description of the method can be found in [our paper](https://doi.org/10.1101/2024.02.01.578426).

![Overview](artwork/overview.png)

Expand Down Expand Up @@ -46,9 +46,12 @@ pip install spotiflow

## Usage

### Training
### Training (2D images)
See the [example training script](scripts/train.py) or the [example notebook](examples/1_train.ipynb) for an example of training. For finetuning an already pretrained model, please refer to the [finetuning example notebook](examples/3_finetune.ipynb).

### Training (3D volumes)
See the [example 3D training script](scripts/train_simple_3d.py). For more information, please refer to the [3D training example notebook](examples/4_train_3d.ipynb). Fine-tuning a 3D model can be done by following the same workflow as to the 2D case.

### Inference (CLI)

You can use the CLI to run inference on an image or folder containing several images. To do that, you can use the following command:
Expand All @@ -61,7 +64,7 @@ where PATH can be either an image or a folder. By default, the command will use

### Inference (API)

The API allows detecting spots in a new image in a few lines of code! Please check the [corresponding example notebook](examples/2_inference.ipynb) and the documentation for a more in-depth explanation.
The API allows detecting spots in a new image in a few lines of code! Please check the [corresponding example notebook](examples/2_inference.ipynb) and the documentation for a more in-depth explanation. The same procedure can be followed for 3D volumes.

```python
from spotiflow.model import Spotiflow
Expand All @@ -82,7 +85,7 @@ points, details = model.predict(img) # points contains the coordinates of the de
```

### Napari plugin
Our napari plugin allows detecting spots directly with an easy-to-use UI. See [napari-spotiflow](https://github.com/weigertlab/napari-spotiflow) for more information.
Our napari plugin allows detecting spots in 2D and 3D directly with an easy-to-use UI. See [napari-spotiflow](https://github.com/weigertlab/napari-spotiflow) for more information.


## For developers
Expand Down Expand Up @@ -128,13 +131,13 @@ If you use this code in your research, please cite [the Spotiflow paper](https:/

```bibtex
@article {dominguezmantes24,
author = {Albert Dominguez Mantes and Antonio Herrera and Irina Khven and Anjalie Schlaeppi and Eftychia Kyriacou and Georgios Tsissios and Can Aztekin and Joachim Ligner and Gioele La Manno and Martin Weigert},
title = {Spotiflow: accurate and efficient spot detection for imaging-based spatial transcriptomics with stereographic flow regression},
author = {Dominguez Mantes, Albert and Herrera, Antonio and Khven, Irina and Schlaeppi, Anjalie and Kyriacou, Eftychia and Tsissios, Georgios and Skoufa, Evangelia and Santangeli, Luca and Buglakova, Elena and Durmus, Emine Berna and Manley, Suliana and Kreshuk, Anna and Arendt, Detlev and Aztekin, Can and Lingner, Joachim and La Manno, Gioele and Weigert, Martin},
title = {Spotiflow: accurate and efficient spot detection for fluorescence microscopy with deep stereographic flow regression},
elocation-id = {2024.02.01.578426},
year = {2024},
doi = {10.1101/2024.02.01.578426},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2024/02/05/2024.02.01.578426},
URL = {https://www.biorxiv.org/content/early/2024/02/05/2024.02.01.578426},
eprint = {https://www.biorxiv.org/content/early/2024/02/05/2024.02.01.578426.full.pdf},
journal = {bioRxiv}
}
Expand Down
242 changes: 242 additions & 0 deletions examples/4_train_3d.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6dca17b2-9f03-4519-9e5a-19241ce76b73",
"metadata": {
"tags": []
},
"source": [
"# Spotiflow: training your own model in 3D"
]
},
{
"cell_type": "markdown",
"id": "4b3ebdc4-b6a2-4a0e-b553-06ecf037cc37",
"metadata": {},
"source": [
"**NOTE**: this notebook requires `napari` to be installed if you want to visualize the data (optional but recommended). You can install it e.g. via `pip install napari[all]` (see [the instructions](https://napari.org/stable/tutorials/fundamentals/installation.html) if you have any issue).\n",
"\n",
"Let's first load all the libraries we're gonna need to detect spots in our volumes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2138a5b-4ed2-4c00-ab5b-0d289ffa41e1",
"metadata": {},
"outputs": [],
"source": [
"from spotiflow.model import Spotiflow, SpotiflowModelConfig\n",
"from spotiflow.sample_data import load_dataset\n",
"from spotiflow.utils import get_data\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"id": "98b51535-bf2f-4ebd-bcbe-40d561c9792e",
"metadata": {},
"source": [
"Similarly to the 2D case, we first load our dataset. We will use the `synth_3d` dataset (corresponding to `synthetic-3d` in the paper, which is a good starting point if you want to then fine-tune on your own data). If you have your own annotated data, you can load it and store it in six different variables corresponding to the training images and spots, to the validation images and spots and to the test images and spots. You can use the `load_data()` function to that end (please [see the docs](https://weigertlab.github.io/spotiflow) to check the data format that the function allows)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "579a5acb-1ebc-436b-a794-23ce6e9f94d1",
"metadata": {},
"outputs": [],
"source": [
"trX, trY, valX, valY, testX, testY = load_dataset(\"synth_3d\", include_test=True)\n",
"# trX, trY, valX, valY, testX, testY = get_data(\"/FOLDER/WITH/DATA\", include_test=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49bd5718",
"metadata": {},
"outputs": [],
"source": [
"trY[0].shape"
]
},
{
"cell_type": "markdown",
"id": "74a14129-713d-432d-9e32-59dc1463d86d",
"metadata": {},
"source": [
"The first two variables should contain the training images and annotations, while the latter the validation ones. While visualizing the images in Python is quite straightforward, that is not the case for 3D volumes. We will use the `napari` library to visualize the volumes. If you don't have it installed, you can do so by checking the first cell in the notebook. The cell below won't run if you don't have `napari` installed, but you can still run the rest of the notebook without it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca5e446c-fea5-44df-a6cb-3058e0ce7985",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import napari\n",
" viewer = napari.Viewer(ndisplay=3)\n",
" viewer.add_image(trX[0], name=\"Training volume\", colormap=\"gray\")\n",
" viewer.add_points(trY[0], name=\"Training spots\", face_color=\"orange\", edge_color=\"orange\", size=5, symbol=\"ring\")\n",
"except ImportError as _:\n",
" print(\"napari not installed, skipping visualization\")\n",
" viewer = None\n",
"except Exception as e:\n",
" raise e"
]
},
{
"cell_type": "markdown",
"id": "6b385ee0-323d-4efa-be41-9b9fe423017c",
"metadata": {},
"source": [
"Training with the default model configuration is straightforward, althought not as much as in the 2D case. First we need to instantiate the model configuration (check [the documentation](https://weigertlab.github.io/spotiflow) for more information about other options):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9e41080-6b9f-44b8-aa7f-b8f533ccb60b",
"metadata": {},
"outputs": [],
"source": [
"config = SpotiflowModelConfig(\n",
" is_3d=True, # 3D model\n",
" grid=(2, 2, 2), # predict on a downsampled grid, this is the value used in the paper\n",
")\n",
"model = Spotiflow(config=config)"
]
},
{
"cell_type": "markdown",
"id": "580f0445-4cb6-49da-9e18-512e7e501281",
"metadata": {},
"source": [
"We can now train the model with calling `.fit()` after setting where we want the model to be stored. Again, you need to define the training parameters. If you want to change some values (_e.g._ the number of epochs), simply change the parameter accordingly (for more information, check [the documentation](https://weigertlab.github.io/spotiflow)):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f37e8f49-cc73-4761-a2dc-17f181ac4d66",
"metadata": {},
"outputs": [],
"source": [
"save_folder = \"models/synth_3d\" # change to where you want to store the model\n",
"train_config = {\n",
" \"num_epochs\": 5,\n",
" \"crop_size\": 128,\n",
" \"crop_size_depth\": 32,\n",
" \"smart_crop\": True,\n",
"}\n",
"model.fit(\n",
" trX,\n",
" trY,\n",
" valX,\n",
" valY,\n",
" save_dir=save_folder,\n",
" train_config=train_config,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f0000786-69b0-4332-8346-c73b86bf3415",
"metadata": {},
"source": [
"Our model is now ready to be used! Let's first check the save folder to make sure the model was stored properly (there should be two `.pt` files (`best.pt` and `last.pt`) as well as three `.yaml` configuration files.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79da07cb-100c-4ff6-9588-a83e9b6286f2",
"metadata": {},
"outputs": [],
"source": [
"!ls $save_folder"
]
},
{
"cell_type": "markdown",
"id": "55ca3e53-89d6-459b-bf97-af871b34fcab",
"metadata": {},
"source": [
"We can also quickly predict on a test image which was not seen during training (see [the inference notebook](./2_inference.ipynb) for more information about predicting as well as model loading): "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1972096-aa37-40a1-b90b-900c61202f0c",
"metadata": {},
"outputs": [],
"source": [
"test_pred, _ = model.predict(testX[0], device=\"auto\")"
]
},
{
"cell_type": "markdown",
"id": "dd4a6108",
"metadata": {},
"source": [
"Let's visualize the results now using `napari` (if it is already running):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68c2682e",
"metadata": {},
"outputs": [],
"source": [
"if viewer is not None:\n",
" for layer in viewer.layers:\n",
" viewer.layers.remove(layer)\n",
" viewer.add_image(testX[0], name=\"Test volume\", colormap=\"gray\")\n",
" viewer.add_points(testY[0], name=\"Predicted test spots\", face_color=\"orange\", edge_color=\"orange\", size=5, symbol=\"ring\")"
]
},
{
"cell_type": "markdown",
"id": "60ce51f9-ef83-4b2e-9a70-57f4a9430487",
"metadata": {},
"source": [
"This notebook shows the most user-friendly way to train models. If you want to dive deeper into the model architecture and tweak the code and you are already comfortable with training DL models, please check [the documentation](https://weigertlab.github.io/spotiflow) to get started."
]
},
{
"cell_type": "markdown",
"id": "a6d34ee6",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cbaidt",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
14 changes: 9 additions & 5 deletions spotiflow/model/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import abc
import argparse
import logging
import json
import yaml
import logging
import sys

from numbers import Number
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

import numpy as np
import yaml

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
log = logging.getLogger(__name__)

Expand Down Expand Up @@ -144,6 +145,8 @@ def __init__(
self.is_3d = bool(is_3d)
if isinstance(grid, int):
self.grid = (grid,)*(3 if self.is_3d else 2)
elif not isinstance(grid, tuple):
self.grid = tuple(grid)
else:
self.grid = grid

Expand All @@ -163,8 +166,8 @@ def is_valid(self):
isinstance(self.in_channels, int) and self.in_channels > 0
), "in_channels must be greater than 0"
assert (
isinstance(self.out_channels, int) and self.out_channels > 0
), "out_channels must be greater than 0"
isinstance(self.out_channels, int) and self.out_channels == 1
), "out_channels must be equal to 1 (multi-channel output not supported yet)"
assert (
isinstance(self.initial_fmaps, int) and self.initial_fmaps > 0
), "initial_fmaps must be greater than 0"
Expand Down Expand Up @@ -212,6 +215,7 @@ def is_valid(self):
assert (
all(isinstance(s, int) and s > 0 and (s == 1 or s % 2 == 0) for s in self.grid)
), "grid must be a tuple containing only 1 or even positive integers"
assert len(np.unique(self.grid)) == 1, "grid must currently be isotropic (all dimensions must be equal)"


class SpotiflowTrainingConfig(SpotiflowConfig):
Expand Down
14 changes: 11 additions & 3 deletions spotiflow/model/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class RegisteredModel:

url: str
md5_hash: str

is_3d: bool

def list_registered():
return list(_REGISTERED.keys())
Expand Down Expand Up @@ -50,13 +50,21 @@ def get_pretrained_model_path(name: str):
"hybiss": RegisteredModel(
url="https://drive.switch.ch/index.php/s/O4hqFSSGX6veLwa/download",
md5_hash="254afa97c137d0bd74fd9c1827f0e323",
is_3d=False,
),
"general": RegisteredModel(
url="https://drive.switch.ch/index.php/s/6AoTEgpIAeQMRvX/download",
md5_hash="9dd31a36b737204e91b040515e3d899e"
md5_hash="9dd31a36b737204e91b040515e3d899e",
is_3d=False,
),
"synth_complex": RegisteredModel(
url="https://drive.switch.ch/index.php/s/CiCjNJaJzpVVD2M/download",
md5_hash="d692fa21da47e4a50b4c52f49442508b"
md5_hash="d692fa21da47e4a50b4c52f49442508b",
is_3d=False,
),
"smfish_3d": RegisteredModel(
url="https://drive.switch.ch/index.php/s/Vym7tqiORZOP5Zt/download",
md5_hash="c5ab30ba3b9ccb07b4c34442d1b5b615",
is_3d=True,
)
}
3 changes: 1 addition & 2 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ def predict(
points.append(curr_pts)
probs.append(curr_probs)

assert self.config.out_channels == 1, "Trying to predict using a multi-channel network, which is not supported yet."
# ! FIXME: This is a temporary fix which will stop working when multi-channel output is implemented
points = points[0]
probs = probs[0]
Expand All @@ -833,8 +834,6 @@ def predict(
flow = flow[0]

else: # Predict with tiling
if self.config.out_channels > 1:
raise NotImplementedError("Tiled prediction not implemented for multi-channel output yet.")
padded_shape = tuple(np.array(x.shape[:actual_n_dims])//corr_grid)
if not skip_details:
y = np.empty(padded_shape, np.float32)
Expand Down
Loading

0 comments on commit 85d6319

Please sign in to comment.