1 parent eb96650 · commit 85d6319 · Showing 7 changed files with 316 additions and 42 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6dca17b2-9f03-4519-9e5a-19241ce76b73", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"# Spotiflow: training your own model in 3D" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4b3ebdc4-b6a2-4a0e-b553-06ecf037cc37", | ||
"metadata": {}, | ||
"source": [ | ||
"**NOTE**: this notebook requires `napari` to be installed if you want to visualize the data (optional but recommended). You can install it e.g. via `pip install napari[all]` (see [the instructions](https://napari.org/stable/tutorials/fundamentals/installation.html) if you have any issue).\n", | ||
"\n", | ||
"Let's first load all the libraries we're gonna need to detect spots in our volumes." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c2138a5b-4ed2-4c00-ab5b-0d289ffa41e1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from spotiflow.model import Spotiflow, SpotiflowModelConfig\n", | ||
"from spotiflow.sample_data import load_dataset\n", | ||
"from spotiflow.utils import get_data\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "98b51535-bf2f-4ebd-bcbe-40d561c9792e", | ||
"metadata": {}, | ||
"source": [ | ||
"Similarly to the 2D case, we first load our dataset. We will use the `synth_3d` dataset (corresponding to `synthetic-3d` in the paper, which is a good starting point if you want to then fine-tune on your own data). If you have your own annotated data, you can load it and store it in six different variables corresponding to the training images and spots, to the validation images and spots and to the test images and spots. You can use the `load_data()` function to that end (please [see the docs](https://weigertlab.github.io/spotiflow) to check the data format that the function allows)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "579a5acb-1ebc-436b-a794-23ce6e9f94d1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trX, trY, valX, valY, testX, testY = load_dataset(\"synth_3d\", include_test=True)\n", | ||
"# trX, trY, valX, valY, testX, testY = get_data(\"/FOLDER/WITH/DATA\", include_test=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "49bd5718", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trY[0].shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "74a14129-713d-432d-9e32-59dc1463d86d", | ||
"metadata": {}, | ||
"source": [ | ||
"The first two variables should contain the training images and annotations, while the latter the validation ones. While visualizing the images in Python is quite straightforward, that is not the case for 3D volumes. We will use the `napari` library to visualize the volumes. If you don't have it installed, you can do so by checking the first cell in the notebook. The cell below won't run if you don't have `napari` installed, but you can still run the rest of the notebook without it." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ca5e446c-fea5-44df-a6cb-3058e0ce7985", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"try:\n", | ||
" import napari\n", | ||
" viewer = napari.Viewer(ndisplay=3)\n", | ||
" viewer.add_image(trX[0], name=\"Training volume\", colormap=\"gray\")\n", | ||
" viewer.add_points(trY[0], name=\"Training spots\", face_color=\"orange\", edge_color=\"orange\", size=5, symbol=\"ring\")\n", | ||
"except ImportError as _:\n", | ||
" print(\"napari not installed, skipping visualization\")\n", | ||
" viewer = None\n", | ||
"except Exception as e:\n", | ||
" raise e" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6b385ee0-323d-4efa-be41-9b9fe423017c", | ||
"metadata": {}, | ||
"source": [ | ||
"Training with the default model configuration is straightforward, althought not as much as in the 2D case. First we need to instantiate the model configuration (check [the documentation](https://weigertlab.github.io/spotiflow) for more information about other options):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a9e41080-6b9f-44b8-aa7f-b8f533ccb60b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"config = SpotiflowModelConfig(\n", | ||
" is_3d=True, # 3D model\n", | ||
" grid=(2, 2, 2), # predict on a downsampled grid, this is the value used in the paper\n", | ||
")\n", | ||
"model = Spotiflow(config=config)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "580f0445-4cb6-49da-9e18-512e7e501281", | ||
"metadata": {}, | ||
"source": [ | ||
"We can now train the model with calling `.fit()` after setting where we want the model to be stored. Again, you need to define the training parameters. If you want to change some values (_e.g._ the number of epochs), simply change the parameter accordingly (for more information, check [the documentation](https://weigertlab.github.io/spotiflow)):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f37e8f49-cc73-4761-a2dc-17f181ac4d66", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"save_folder = \"models/synth_3d\" # change to where you want to store the model\n", | ||
"train_config = {\n", | ||
" \"num_epochs\": 5,\n", | ||
" \"crop_size\": 128,\n", | ||
" \"crop_size_depth\": 32,\n", | ||
" \"smart_crop\": True,\n", | ||
"}\n", | ||
"model.fit(\n", | ||
" trX,\n", | ||
" trY,\n", | ||
" valX,\n", | ||
" valY,\n", | ||
" save_dir=save_folder,\n", | ||
" train_config=train_config,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f0000786-69b0-4332-8346-c73b86bf3415", | ||
"metadata": {}, | ||
"source": [ | ||
"Our model is now ready to be used! Let's first check the save folder to make sure the model was stored properly (there should be two `.pt` files (`best.pt` and `last.pt`) as well as three `.yaml` configuration files.)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "79da07cb-100c-4ff6-9588-a83e9b6286f2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!ls $save_folder" | ||
] | ||
}, | ||
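{ | ||
"cell_type": "markdown", | ||
"id": "reload-model-note", | ||
"metadata": {}, | ||
"source": [ | ||
"We can also reload the trained model straight from that folder. The cell below is a minimal sketch using `Spotiflow.from_folder` (the name `reloaded_model` is only illustrative; see [the inference notebook](./2_inference.ipynb) for the full loading options):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "reload-model-code", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# reload the trained model from the save folder (sketch; the inference notebook covers loading in detail)\n", | ||
"reloaded_model = Spotiflow.from_folder(save_folder)" | ||
] | ||
}, | ||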
{ | ||
"cell_type": "markdown", | ||
"id": "55ca3e53-89d6-459b-bf97-af871b34fcab", | ||
"metadata": {}, | ||
"source": [ | ||
"We can also quickly predict on a test image which was not seen during training (see [the inference notebook](./2_inference.ipynb) for more information about predicting as well as model loading): " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f1972096-aa37-40a1-b90b-900c61202f0c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"test_pred, _ = model.predict(testX[0], device=\"auto\")" | ||
] | ||
}, | ||
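{ | ||
"cell_type": "markdown", | ||
"id": "count-check-note", | ||
"metadata": {}, | ||
"source": [ | ||
"Before visualizing, a quick sanity check: compare how many spots were detected with how many were annotated in that test volume (a rough indication only, the counts won't match exactly):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "count-check-code", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# rough sanity check: number of detected spots vs. number of annotated spots\n", | ||
"print(f\"Detected {len(test_pred)} spots ({len(testY[0])} annotated)\")" | ||
] | ||
}, | ||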
{ | ||
"cell_type": "markdown", | ||
"id": "dd4a6108", | ||
"metadata": {}, | ||
"source": [ | ||
"Let's visualize the results now using `napari` (if it is already running):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "68c2682e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"if viewer is not None:\n", | ||
" for layer in viewer.layers:\n", | ||
" viewer.layers.remove(layer)\n", | ||
" viewer.add_image(testX[0], name=\"Test volume\", colormap=\"gray\")\n", | ||
" viewer.add_points(testY[0], name=\"Predicted test spots\", face_color=\"orange\", edge_color=\"orange\", size=5, symbol=\"ring\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "60ce51f9-ef83-4b2e-9a70-57f4a9430487", | ||
"metadata": {}, | ||
"source": [ | ||
"This notebook shows the most user-friendly way to train models. If you want to dive deeper into the model architecture and tweak the code and you are already comfortable with training DL models, please check [the documentation](https://weigertlab.github.io/spotiflow) to get started." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "cbaidt", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.16" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |