32 precision
Ben-Salmon committed Aug 22, 2024
1 parent 1cf9e2a · commit 1e1b7a8
Showing 1 changed file with 25 additions and 70 deletions.
95 changes: 25 additions & 70 deletions 03_COSDD/solution.ipynb
@@ -25,12 +25,11 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import logging\n",
"\n",
"import torch\n",
"import tifffile\n",
@@ -53,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -120,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"tags": [
"solution"
@@ -359,20 +358,12 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Effective batch size: 16\n"
]
}
],
"source": [
"real_batch_size = 16\n",
"n_grad_batches = 1\n",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"real_batch_size = 4\n",
"n_grad_batches = 4\n",
"print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n",
"crop_size = (256, 256)\n",
"train_split = 0.9\n",
@@ -472,9 +463,9 @@
"outputs": [],
"source": [
"dimensions = ... ### Insert a value here\n",
"s_code_channels = 16\n",
"s_code_channels = 32\n",
"\n",
"n_layers = 4\n",
"n_layers = 6\n",
"z_dims = [s_code_channels // 2] * n_layers\n",
"downsampling = [1] * n_layers\n",
"lvae = LadderVAE(\n",
@@ -492,8 +483,8 @@
" s_code_channels=s_code_channels,\n",
" kernel_size=5,\n",
" noise_direction=... ### Insert a value here\n",
" n_filters=16,\n",
" n_layers=3,\n",
" n_filters=32,\n",
" n_layers=4,\n",
" n_gaussians=4,\n",
" dimensions=dimensions,\n",
")\n",
@@ -526,35 +517,24 @@
" data_mean=low_snr.mean(),\n",
" data_std=low_snr.std(),\n",
" n_grad_batches=n_grad_batches,\n",
" checkpointed=False,\n",
" checkpointed=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"tags": [
"solution"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'vae' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vae'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ar_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ar_decoder'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 's_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['s_decoder'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'direct_denoiser' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['direct_denoiser'])`.\n"
]
}
],
"outputs": [],
"source": [
"dimensions = 2 ### Insert a value here\n",
"s_code_channels = 16\n",
"s_code_channels = 32\n",
"\n",
"n_layers = 4\n",
"n_layers = 6\n",
"z_dims = [s_code_channels // 2] * n_layers\n",
"downsampling = [1] * n_layers\n",
"lvae = LadderVAE(\n",
@@ -572,8 +552,8 @@
" s_code_channels=s_code_channels,\n",
" kernel_size=5,\n",
" noise_direction=\"x\", ### Insert a value here\n",
" n_filters=16,\n",
" n_layers=3,\n",
" n_filters=32,\n",
" n_layers=4,\n",
" n_gaussians=4,\n",
" dimensions=dimensions,\n",
")\n",
@@ -606,7 +586,7 @@
" data_mean=low_snr.mean(),\n",
" data_std=low_snr.std(),\n",
" n_grad_batches=n_grad_batches,\n",
" checkpointed=False,\n",
" checkpointed=True,\n",
")"
]
},
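Both the exercise and solution versions of the Hub are now built with checkpointed=True. Assuming this flag enables activation (gradient) checkpointing (the exact behaviour is defined inside COSDD, so treat this as an assumption), the underlying mechanism is to recompute intermediate activations during the backward pass instead of storing them, trading extra compute for a smaller memory footprint. A minimal PyTorch sketch of that mechanism, using a toy block that is not the notebook's model:

    import torch
    from torch.utils.checkpoint import checkpoint

    # Toy block standing in for a memory-hungry sub-network (illustration only).
    block = torch.nn.Sequential(
        torch.nn.Linear(64, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 64),
    )
    x = torch.randn(8, 64, requires_grad=True)

    # The activations inside `block` are not stored; they are recomputed
    # during backward, so peak memory falls at the cost of a second forward pass.
    y = checkpoint(block, x, use_reentrant=False)
    y.sum().backward()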
@@ -723,30 +703,18 @@
" max_time=max_time, # Remove this time limit to train the model fully\n",
" log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
" callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"tags": [
"solution"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using bfloat16 Automatic Mixed Precision (AMP)\n",
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"outputs": [],
"source": [
"model_name = \"mito-confocal\" ### Insert a value here\n",
"checkpoint_path = os.path.join(\"checkpoints\", model_name)\n",
@@ -764,7 +732,6 @@
" max_time=max_time, # Remove this time limit to train the model fully\n",
" log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
" callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
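This is the change the commit title refers to: with precision="bf16-mixed" removed from both Trainer calls, PyTorch Lightning falls back to its default full 32-bit precision ("32-true" in Lightning 2.x) instead of bfloat16 mixed precision. A minimal sketch of the two settings; the max_epochs value here is just a placeholder:

    import pytorch_lightning as pl

    # Previous setting: bfloat16 automatic mixed precision.
    # trainer = pl.Trainer(precision="bf16-mixed", max_epochs=1)

    # After this commit: no precision argument, so Lightning uses its default
    # full 32-bit precision ("32-true" on Lightning 2.x).
    trainer = pl.Trainer(max_epochs=1)
    print(trainer.precision)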
@@ -798,22 +765,12 @@
"# Exercise 2. Inference with COSDD"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "logger = logging.getLogger('pytorch_lightning')\n",
- "logger.setLevel(logging.WARNING)"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1. Load test data\n",
"The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 2 to speed up inference."
"The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference."
]
},
{
@@ -823,7 +780,7 @@
"outputs": [],
"source": [
"lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n",
"n_test_images = 2\n",
"n_test_images = 10\n",
"# load the data\n",
"test_set = tifffile.imread(lowsnr_path)\n",
"test_set = test_set[:n_test_images, np.newaxis]\n",
@@ -909,7 +866,6 @@
" enable_progress_bar=False,\n",
" enable_checkpointing=False,\n",
" logger=False,\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
@@ -932,7 +888,6 @@
" enable_progress_bar=False,\n",
" enable_checkpointing=False,\n",
" logger=False,\n",
" precision=\"bf16-mixed\",\n",
")"
]
},