smaller model
Ben-Salmon committed Aug 22, 2024
1 parent e82fd1a commit e2782c7
Showing 1 changed file with 59 additions and 29 deletions.
03_COSDD/solution.ipynb: 88 changes (59 additions, 29 deletions)
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -53,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -120,7 +120,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"tags": [
"solution"
@@ -359,12 +359,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"real_batch_size = 4\n",
"n_grad_batches = 4\n",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Effective batch size: 16\n"
]
}
],
"source": [
"real_batch_size = 16\n",
"n_grad_batches = 1\n",
"print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n",
"crop_size = (256, 256)\n",
"train_split = 0.9\n",
@@ -464,9 +472,9 @@
"outputs": [],
"source": [
"dimensions = ... ### Insert a value here\n",
"s_code_channels = 64\n",
"s_code_channels = 16\n",
"\n",
"n_layers = 6\n",
"n_layers = 4\n",
"z_dims = [s_code_channels // 2] * n_layers\n",
"downsampling = [1] * n_layers\n",
"lvae = LadderVAE(\n",
@@ -484,9 +492,9 @@
" s_code_channels=s_code_channels,\n",
" kernel_size=5,\n",
" noise_direction=... ### Insert a value here\n",
" n_filters=64,\n",
" n_layers=4,\n",
" n_gaussians=5,\n",
" n_filters=16,\n",
" n_layers=3,\n",
" n_gaussians=4,\n",
" dimensions=dimensions,\n",
")\n",
"\n",
@@ -518,24 +526,35 @@
" data_mean=low_snr.mean(),\n",
" data_std=low_snr.std(),\n",
" n_grad_batches=n_grad_batches,\n",
" checkpointed=True,\n",
" checkpointed=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {
"tags": [
"solution"
]
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'vae' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vae'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ar_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ar_decoder'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 's_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['s_decoder'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'direct_denoiser' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['direct_denoiser'])`.\n"
]
}
],
"source": [
"dimensions = 2 ### Insert a value here\n",
"s_code_channels = 64\n",
"s_code_channels = 16\n",
"\n",
"n_layers = 6\n",
"n_layers = 4\n",
"z_dims = [s_code_channels // 2] * n_layers\n",
"downsampling = [1] * n_layers\n",
"lvae = LadderVAE(\n",
@@ -553,8 +572,8 @@
" s_code_channels=s_code_channels,\n",
" kernel_size=5,\n",
" noise_direction=\"x\", ### Insert a value here\n",
" n_filters=64,\n",
" n_layers=4,\n",
" n_filters=16,\n",
" n_layers=3,\n",
" n_gaussians=4,\n",
" dimensions=dimensions,\n",
")\n",
@@ -587,7 +606,7 @@
" data_mean=low_snr.mean(),\n",
" data_std=low_snr.std(),\n",
" n_grad_batches=n_grad_batches,\n",
" checkpointed=True,\n",
" checkpointed=False,\n",
")"
]
},
@@ -613,7 +632,7 @@
"3. Enter `tensorboard --logdir 05_image_restoration/03_COSDD/checkpoints`\n",
"4. Finally, open a browser and enter localhost:6006 in the address bar.\n",
"\n",
"Once you're in tensorboard, you'll see the training logs of your model and the logs of a model that's been trained for 3.5 hours.\n",
"Once you're in tensorboard, you'll see the training logs of your model and the logs of a model that's already been trained for 3.5 hours.\n",
"</div>"
]
},
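As an aside, the `tensorboard --logdir ...` command from step 3 can also be launched from a Python session instead of a separate terminal. A minimal sketch, assuming the standard tensorboard package is installed, with the logdir path taken from the instructions above:

```python
from tensorboard import program

tb = program.TensorBoard()
tb.configure(argv=[None, "--logdir", "05_image_restoration/03_COSDD/checkpoints"])
url = tb.launch()  # starts a local server, e.g. http://localhost:6006/
print(f"TensorBoard is running at {url}")
```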
@@ -702,21 +721,32 @@
" devices=1,\n",
" max_epochs=max_epochs,\n",
" max_time=max_time, # Remove this time limit to train the model fully\n",
" log_every_n_steps=len(train_set) // (4 * real_batch_size),\n",
" log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
" callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {
"tags": [
"solution"
]
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using bfloat16 Automatic Mixed Precision (AMP)\n",
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"model_name = \"mito-confocal\" ### Insert a value here\n",
"checkpoint_path = os.path.join(\"checkpoints\", model_name)\n",
@@ -732,7 +762,7 @@
" devices=1,\n",
" max_epochs=max_epochs,\n",
" max_time=max_time, # Remove this time limit to train the model fully\n",
" log_every_n_steps=len(train_set) // (4 * real_batch_size),\n",
" log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
" callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
" precision=\"bf16-mixed\",\n",
")"
@@ -783,7 +813,7 @@
"metadata": {},
"source": [
"### 2.1. Load test data\n",
"The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference."
"The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 2 to speed up inference."
]
},
{
@@ -793,7 +823,7 @@
"outputs": [],
"source": [
"lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n",
"n_test_images = 10\n",
"n_test_images = 2\n",
"# load the data\n",
"test_set = tifffile.imread(lowsnr_path)\n",
"test_set = test_set[:n_test_images, np.newaxis]\n",
