From df65b53db076b7a139e0bb9985305a0c1610dba8 Mon Sep 17 00:00:00 2001
From: cssprad1
Date: Thu, 2 May 2024 17:07:03 -0400
Subject: [PATCH] made small updates to the recon notebook

---
 notebooks/satvision-toa-reconstruction.ipynb | 182 +++++++++++++------
 1 file changed, 124 insertions(+), 58 deletions(-)

diff --git a/notebooks/satvision-toa-reconstruction.ipynb b/notebooks/satvision-toa-reconstruction.ipynb
index b4b8f1e..e2996f1 100644
--- a/notebooks/satvision-toa-reconstruction.ipynb
+++ b/notebooks/satvision-toa-reconstruction.ipynb
@@ -7,7 +7,7 @@
    "source": [
     "# Satvision-TOA Reconstruction Notebook\n",
     "\n",
-    "Version: 02.20.24\n",
+    "Version: 04.30.24\n",
     "\n",
     "Env: `Python [conda env:ilab-pytorch]`"
    ]
   },
     "import time\n",
     "import random\n",
     "import datetime\n",
+    "from tqdm import tqdm\n",
     "import numpy as np\n",
     "import logging\n",
     "\n",
   {
    "cell_type": "markdown",
    "id": "d841e464-f880-4e53-bf31-f9f225713918",
    "metadata": {},
    "source": [
-    "## Configuration"
+    "## 1. Configuration"
    ]
   },
     "```bash\n",
     "\n",
     "git lfs install\n",
     "\n",
-    "git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-base\n",
+    "git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-huge-patch8-window12-192\n",
     "```\n",
     "\n",
-    "Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth. "
+    "Note: If using git w/ ssh, make sure your ssh keys are set up for ssh auth:\n",
+    "https://huggingface.co/docs/hub/security-git-ssh\n",
+    "\n",
+    "```bash\n",
+    "eval $(ssh-agent)\n",
+    "\n",
+    "# If this greets you as 'anonymous', follow the next steps\n",
+    "ssh -T git@hf.co\n",
+    "\n",
+    "# Check if ssh-agent is using the proper key\n",
+    "ssh-add -l\n",
+    "\n",
+    "# If not, add your key\n",
+    "ssh-add ~/.ssh/your-key\n",
+    "\n",
+    "# Or, to use the default id_* key, just run\n",
+    "ssh-add\n",
+    "\n",
+    "```"
    ]
   },
    "metadata": {},
    "outputs": [],
    "source": [
-    "MODEL_PATH: str = '../../satvision-toa-base/satvision-toa_84M_2M_100.pth'\n",
-    "CONFIG_PATH: str = '../../satvision-toa-base/mim_pretrain_swinv2_satvision-toa_base_192_window12_800ep.yaml'\n",
+    "MODEL_PATH: str = '../../satvision-toa-huge-patch8-window12-192/mp_rank_00_model_states.pt'\n",
+    "CONFIG_PATH: str = '../../satvision-toa-huge-patch8-window12-192/mim_pretrain_swinv2_satvision_huge_192_window12_100ep.yaml'\n",
     "\n",
-    "BATCH_SIZE: int = 64 # Want to report loss on every image? Change to 1.\n",
     "OUTPUT: str = '.'\n",
-    "TAG: str = 'satvision-base-toa-reconstruction'\n",
-    "DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets'\n",
+    "TAG: str = 'satvision-huge-toa-reconstruction'\n",
+    "DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/validation/sv_toa_128_chip_validation_04_24.npy'\n",
     "DATA_PATHS: list = [DATA_PATH]"
    ]
   },
     "config.defrost()\n",
     "config.MODEL.RESUME = MODEL_PATH\n",
     "config.DATA.DATA_PATHS = DATA_PATHS\n",
-    "config.DATA.BATCH_SIZE = BATCH_SIZE\n",
     "config.OUTPUT = OUTPUT\n",
     "config.TAG = TAG\n",
     "config.freeze()"
     "logger.addHandler(console)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "11ebd497-7741-41a7-af9d-0ee49a6313a4",
+   "metadata": {},
+   "source": [
+    "## 2. Load model weights from checkpoint"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "checkpoint = torch.load(MODEL_PATH)\n",
     "model = build_model(config, pretrain=True)\n",
-    "model.load_state_dict(checkpoint['model'])\n",
+    "model.load_state_dict(checkpoint['module'])  # If 'module' raises a KeyError, try 'model'\n",
     "n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
     "logger.info(f\"number of params: {n_parameters}\")\n",
     "model.cuda()\n",
     "model.eval()"
    ]
   },
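+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If `load_state_dict` raises a `KeyError`, the top-level key depends on how\n",
+    "the checkpoint was written (DeepSpeed model-state files usually nest the\n",
+    "weights under 'module'). A minimal sketch for inspecting what the\n",
+    "checkpoint actually contains before loading:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: list the checkpoint's top-level keys so the right\n",
+    "# state-dict key ('module' vs 'model') can be chosen above.\n",
+    "print(type(checkpoint))\n",
+    "if isinstance(checkpoint, dict):\n",
+    "    print(list(checkpoint.keys()))"
+   ]
+  },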
  {
    "cell_type": "markdown",
-   "id": "bd9ba52e-62ca-4800-b2aa-deaaea64be9f",
+   "id": "b500d13b-89d7-4cd8-a36a-ab6f10f6a397",
    "metadata": {},
    "source": [
-    "## Dataloader"
+    "## 3. Load evaluation set (from numpy file)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "75c52c66-f322-413c-be76-6c7abfd159bc",
+   "id": "73a8d307-de9b-4617-abdd-dae1e7c2521a",
    "metadata": {},
    "outputs": [],
    "source": [
-    "dataloader = mim_webdataset_datamodule.build_mim_dataloader(config, logger)"
+    "# Use the Masked-Image-Modeling transform\n",
+    "transform = SimmimTransform(config)\n",
+    "\n",
+    "# The reconstruction evaluation set is a single numpy file\n",
+    "validation_dataset_path = config.DATA.DATA_PATHS[0]\n",
+    "validation_dataset = np.load(validation_dataset_path)\n",
+    "num_images = validation_dataset.shape[0]\n",
+    "\n",
+    "# Apply the transform to each image in the batch;\n",
+    "# a mask is auto-generated inside the transform\n",
+    "imgMasks = [transform(validation_dataset[idx])\n",
+    "            for idx in range(num_images)]\n",
+    "\n",
+    "# Separate imgs and masks, cast masks to torch tensors\n",
+    "img = torch.stack([imgMask[0] for imgMask in imgMasks])\n",
+    "mask = torch.stack([torch.from_numpy(imgMask[1])\n",
+    "                    for imgMask in imgMasks])"
    ]
   },
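+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check on the transformed batch (a sketch; the exact sizes\n",
+    "depend on the config, so the shapes below are only printed rather than\n",
+    "asserted against fixed values):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: img should be (N, C, H, W) floats and mask a\n",
+    "# patch-level 0/1 array per image, with matching batch sizes.\n",
+    "print(img.shape, img.dtype)\n",
+    "print(mask.shape, mask.dtype)\n",
+    "assert img.shape[0] == mask.shape[0]"
+   ]
+  },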
  {
    "cell_type": "markdown",
    "id": "55acf5e9-eb2a-496c-baa6-3b74503a2978",
    "metadata": {},
    "source": [
-    "## Prediction helper functions"
+    "## 4. Prediction helper functions"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "def predict(model, dataloader, num_batches=5):\n",
     "    inputs = []\n",
     "    outputs = []\n",
     "    masks = []\n",
     "    losses = []\n",
+    "    with tqdm(total=num_batches) as pbar:\n",
     "\n",
-    "    for idx, img_mask in enumerate(dataloader):\n",
-    "        \n",
-    "        if idx > num_batches:\n",
-    "            return inputs, outputs, masks, losses\n",
+    "        for idx, img_mask in enumerate(dataloader):\n",
+    "\n",
+    "            if idx >= num_batches:\n",
+    "                return inputs, outputs, masks, losses\n",
+    "\n",
+    "            pbar.update(1)\n",
     "\n",
-    "        img_mask = img_mask[0]\n",
+    "            img_mask = img_mask[0]\n",
     "\n",
-    "        img = torch.stack([pair[0] for pair in img_mask])\n",
-    "        mask = torch.stack([pair[1] for pair in img_mask])\n",
+    "            img = torch.stack([pair[0] for pair in img_mask])\n",
+    "            mask = torch.stack([pair[1] for pair in img_mask])\n",
     "\n",
-    "        img = img.cuda(non_blocking=True)\n",
-    "        mask = mask.cuda(non_blocking=True)\n",
+    "            img = img.cuda(non_blocking=True)\n",
+    "            mask = mask.cuda(non_blocking=True)\n",
     "\n",
-    "        with torch.no_grad():\n",
-    "            with amp.autocast(enabled=config.ENABLE_AMP):\n",
-    "                z = model.encoder(img, mask)\n",
-    "                img_recon = model.decoder(z)\n",
-    "                loss = model(img, mask)\n",
+    "            with torch.no_grad():\n",
+    "                with amp.autocast(enabled=config.ENABLE_AMP):\n",
+    "                    z = model.encoder(img, mask)\n",
+    "                    img_recon = model.decoder(z)\n",
+    "                    loss = model(img, mask)\n",
     "\n",
-    "        inputs.extend(img.cpu())\n",
-    "        masks.extend(mask.cpu())\n",
-    "        outputs.extend(img_recon.cpu())\n",
-    "        losses.append(losses)\n",
+    "            inputs.extend(img.cpu())\n",
+    "            masks.extend(mask.cpu())\n",
+    "            outputs.extend(img_recon.cpu())\n",
+    "            losses.append(loss.cpu())\n",
     "    \n",
     "    return inputs, outputs, masks, losses\n",
     "\n",
     "\n",
     "def process_mask(mask):\n",
-    "    mask = mask.unsqueeze(0)\n",
-    "    mask = mask.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous()\n",
-    "    mask = mask[0, 0, :, :]\n",
-    "    mask = np.stack([mask, mask, mask], axis=-1)\n",
-    "    return mask\n",
+    "    mask_img = mask.unsqueeze(0)\n",
+    "    mask_img = mask_img.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous()\n",
+    "    mask_img = mask_img[0, 0, :, :]\n",
+    "    mask_img = np.stack([mask_img, mask_img, mask_img], axis=-1)\n",
+    "    return mask_img\n",
     "\n",
     "\n",
     "def process_prediction(image, img_recon, mask, rgb_index):\n",
-    "    img_normed = minmax_norm(image.numpy())\n",
     "\n",
     "    mask = process_mask(mask)\n",
     "    \n",
     "    red_idx = rgb_index[0]\n",
     "    blue_idx = rgb_index[1]\n",
     "    green_idx = rgb_index[2]\n",
     "\n",
-    "    rgb_image = np.stack((img_normed[red_idx, :, :],\n",
-    "                          img_normed[blue_idx, :, :],\n",
-    "                          img_normed[green_idx, :, :]),\n",
-    "                         axis=-1)\n",
+    "    image = image.numpy()\n",
+    "    rgb_image = np.stack((image[red_idx, :, :],\n",
+    "                          image[blue_idx, :, :],\n",
+    "                          image[green_idx, :, :]),\n",
+    "                         axis=-1)\n",
+    "    rgb_image = minmax_norm(rgb_image)\n",
     "\n",
-    "    img_recon = minmax_norm(img_recon.numpy())\n",
+    "    img_recon = img_recon.numpy()\n",
     "    rgb_image_recon = np.stack((img_recon[red_idx, :, :],\n",
     "                                img_recon[blue_idx, :, :],\n",
     "                                img_recon[green_idx, :, :]),\n",
     "                               axis=-1)\n",
+    "    rgb_image_recon = minmax_norm(rgb_image_recon)\n",
     "\n",
     "    rgb_masked = np.where(mask == 0, rgb_image, rgb_image_recon)\n",
     "    rgb_image_masked = np.where(mask == 1, 0, rgb_image)\n",
     "    rgb_recon_masked = rgb_masked\n",
     "    \n",
     "    return rgb_image, rgb_image_masked, rgb_recon_masked, mask\n",
     "\n",
     "\n",
-    "def plot_export_pdf(path, num_sample, inputs, outputs, masks, rgb_index):\n",
-    "    random_subsample = random.sample(range(len(inputs)), num_sample)\n",
+    "def plot_export_pdf(path, inputs, outputs, masks, rgb_index):\n",
     "    pdf_plot_obj = PdfPages(path)\n",
     "\n",
-    "    for idx in random_subsample:\n",
+    "    for idx in range(len(inputs)):\n",
     "        # prediction processing\n",
     "        image = inputs[idx]\n",
     "        img_recon = outputs[idx]\n",
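+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A standalone sketch of what `process_mask` produces. The 48x48 dummy\n",
+    "size is illustrative only; the real patch-mask size comes from the\n",
+    "transform's mask generator:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: process_mask upsamples a patch-level mask 4x in each\n",
+    "# spatial dim and stacks it into 3 channels for RGB masking.\n",
+    "dummy_mask = torch.ones(48, 48)\n",
+    "print(process_mask(dummy_mask).shape)  # expected: (192, 192, 3)"
+   ]
+  },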
  {
    "cell_type": "markdown",
    "id": "551c44b5-6d88-45c4-b397-c38de8064544",
    "metadata": {},
    "source": [
-    "## Predict"
+    "## 5. Predict"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "fa43bfaf-6379-43d5-9be3-bd0e55f5ca12",
+   "id": "4e695cc3-b869-4fc2-b360-b45f3b81affd",
    "metadata": {},
    "outputs": [],
    "source": [
-    "%%time\n",
-    "\n",
-    "inputs, outputs, masks, losses = predict(model, dataloader, num_batches=5)"
+    "inputs = []\n",
+    "outputs = []\n",
+    "masks = []\n",
+    "losses = []\n",
+    "\n",
+    "# We could do this in a single batch; however, we want to\n",
+    "# report the loss per image instead of per batch.\n",
+    "for i in tqdm(range(img.shape[0])):\n",
+    "    single_img = img[i].unsqueeze(0)\n",
+    "    single_mask = mask[i].unsqueeze(0)\n",
+    "    single_img = single_img.cuda(non_blocking=True)\n",
+    "    single_mask = single_mask.cuda(non_blocking=True)\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        z = model.encoder(single_img, single_mask)\n",
+    "        img_recon = model.decoder(z)\n",
+    "        loss = model(single_img, single_mask)\n",
+    "\n",
+    "    inputs.extend(single_img.cpu())\n",
+    "    masks.extend(single_mask.cpu())\n",
+    "    outputs.extend(img_recon.cpu())\n",
+    "    losses.append(loss.cpu())"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "dc3f102c-94df-4d9e-8040-52197a7e71db",
    "metadata": {},
    "source": [
-    "## Plot and write to PDF"
+    "## 6. Plot and write to PDF\n",
+    "\n",
+    "Writes out all of the predictions to a PDF file"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "pdf_path = '../../satvision-toa-reconstruction-pdf-02.20.pdf'\n",
-    "num_samples = 10 # Number of random samples from the predictions\n",
-    "rgb_index = [0, 3, 2] # Indices of [Red band, Blue band, Green band]\n",
+    "pdf_path = '../../satvision-toa-reconstruction-pdf-huge-patch-8-04.30.pdf'\n",
+    "rgb_index = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n",
     "\n",
-    "plot_export_pdf(pdf_path, num_samples, inputs, outputs, masks, rgb_index)"
+    "plot_export_pdf(pdf_path, inputs, outputs, masks, rgb_index)"
    ]
   },
   {