diff --git a/notebooks/satvision-toa-reconstruction.ipynb b/notebooks/satvision-toa-reconstruction.ipynb
index 2f98f75..b7a9dde 100644
--- a/notebooks/satvision-toa-reconstruction.ipynb
+++ b/notebooks/satvision-toa-reconstruction.ipynb
@@ -7,7 +7,7 @@ "source": [
     "# Satvision-TOA Reconstruction Notebook\n",
     "\n",
-    "Version: 03.15.24\n",
+    "Version: 04.30.24\n",
     "\n",
     "Env: `Python [conda env:ilab-pytorch]`"
    ]
   },
@@ -34,6 +34,7 @@ "source": [
     "import time\n",
     "import random\n",
     "import datetime\n",
+    "from tqdm import tqdm\n",
     "import numpy as np\n",
     "import logging\n",
     "\n",
@@ -75,7 +76,7 @@
    "id": "d841e464-f880-4e53-bf31-f9f225713918",
    "metadata": {},
    "source": [
-    "## Configuration"
+    "## 1. Configuration"
    ]
   },
   {
@@ -91,20 +92,27 @@
    "source": [
     "\n",
     "git lfs install\n",
     "\n",
-    "git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-huge\n",
+    "git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-huge-patch8-window12-192\n",
     "```\n",
     "\n",
-    "Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth. \n",
+    "Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth.\n",
+    "https://huggingface.co/docs/hub/security-git-ssh\n",
     "\n",
-    "If experiencing ssh-related authentication issues:\n",
     "```bash\n",
-    "eval `ssh-agent -s` # starts ssh-agent\n",
+    "eval $(ssh-agent)\n",
     "\n",
-    "ssh-add -l # is your ssh key added to the agent?\n",
+    "# If this greets you as 'anonymous', follow the next steps.\n",
+    "ssh -T git@hf.co\n",
     "\n",
-    "ssh-add ~/.ssh/id_xxxx # adds ssh ID to ssh-agent\n",
+    "# Check whether ssh-agent is using the proper key\n",
+    "ssh-add -l\n",
+    "\n",
+    "# If not, add it\n",
+    "ssh-add ~/.ssh/your-key\n",
+    "\n",
+    "# Or, if you want to use the default id_* key, just run\n",
+    "ssh-add\n",
-    "ssh -T git@hf.co # Should return \"Hi , welcome to Hugging Face.\"\n",
     "```"
    ]
   },
@@ -115,13 +123,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "MODEL_PATH: str = '../../satvision-toa-huge/ckpt_epoch_100.pth'\n",
-    "CONFIG_PATH: str = '../../satvision-toa-huge/mim_pretrain_swinv2_satvision_huge_192_window12_200ep.yaml'\n",
+    "MODEL_PATH: str = '../../satvision-toa-huge-patch8-window12-192/mp_rank_00_model_states.pt'\n",
+    "CONFIG_PATH: str = '../../satvision-toa-huge-patch8-window12-192/mim_pretrain_swinv2_satvision_huge_192_window12_100ep.yaml'\n",
     "\n",
-    "BATCH_SIZE: int = 1 # Want to report loss on every image? Change to 1.\n",
     "OUTPUT: str = '.'\n",
-    "TAG: str = 'satvision-base-toa-reconstruction'\n",
-    "DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets'\n",
+    "TAG: str = 'satvision-huge-toa-reconstruction'\n",
+    "DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/validation/sv_toa_128_chip_validation_04_24.npy'\n",
     "DATA_PATHS: list = [DATA_PATH]"
    ]
   },
@@ -140,7 +147,6 @@ "source": [
     "config.defrost()\n",
     "config.MODEL.RESUME = MODEL_PATH\n",
     "config.DATA.DATA_PATHS = DATA_PATHS\n",
-    "config.DATA.BATCH_SIZE = BATCH_SIZE\n",
     "config.OUTPUT = OUTPUT\n",
     "config.TAG = TAG\n",
     "config.freeze()"
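
Note: the configuration hunks above repoint `MODEL_PATH` at a DeepSpeed-style `mp_rank_00_model_states.pt` file and `DATA_PATH` at a single `.npy` validation file. A quick check before building the model can catch a bad or incomplete `git lfs` clone early. A minimal sketch, not part of the notebook, using the names defined above:

```python
import os

# Fail fast if the clone or download is incomplete (sketch only).
for path in (MODEL_PATH, CONFIG_PATH, DATA_PATH):
    if not os.path.exists(path):
        raise FileNotFoundError(f'Missing required file: {path}')
```
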
@@ -169,6 +175,14 @@ "source": [
     "logger.addHandler(console)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "11ebd497-7741-41a7-af9d-0ee49a6313a4",
+   "metadata": {},
+   "source": [
+    "## 2. Load model weights from checkpoint"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -178,7 +192,7 @@
    "source": [
     "checkpoint = torch.load(MODEL_PATH)\n",
     "model = build_model(config, pretrain=True)\n",
-    "model.load_state_dict(checkpoint['model'])\n",
+    "model.load_state_dict(checkpoint['module']) # If 'module' is missing, try 'model'\n",
     "n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
     "logger.info(f\"number of params: {n_parameters}\")\n",
     "model.cuda()\n",
@@ -187,32 +201,36 @@ "model.eval()"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "bd9ba52e-62ca-4800-b2aa-deaaea64be9f",
+   "id": "b500d13b-89d7-4cd8-a36a-ab6f10f6a397",
    "metadata": {},
    "source": [
-    "## Dataloader"
+    "## 3. Load evaluation set (from numpy file)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "75c52c66-f322-413c-be76-6c7abfd159bc",
+   "id": "73a8d307-de9b-4617-abdd-dae1e7c2521a",
    "metadata": {},
    "outputs": [],
    "source": [
-    "dataset = MODIS22MDataset(config,\n",
-    "                          config.DATA.DATA_PATHS,\n",
-    "                          split=\"train\",\n",
-    "                          img_size=config.DATA.IMG_SIZE,\n",
-    "                          transform=SimmimTransform(config),\n",
-    "                          batch_size=config.DATA.BATCH_SIZE).dataset()\n",
-    "\n",
-    "dataloader = torch.utils.data.DataLoader(\n",
-    "    dataset,\n",
-    "    batch_size=None, # Change if not using webdataset as underlying dataset type\n",
-    "    num_workers=15,\n",
-    "    shuffle=False,\n",
-    "    pin_memory=True,)"
+    "# Use the Masked-Image-Modeling transform\n",
+    "transform = SimmimTransform(config)\n",
+    "\n",
+    "# The reconstruction evaluation set is a single numpy file\n",
+    "validation_dataset_path = config.DATA.DATA_PATHS[0]\n",
+    "validation_dataset = np.load(validation_dataset_path)\n",
+    "len_batch = range(validation_dataset.shape[0])\n",
+    "\n",
+    "# Apply the transform to each image in the batch;\n",
+    "# a mask is auto-generated in the transform\n",
+    "imgMasks = [transform(validation_dataset[idx]) for idx \\\n",
+    "            in len_batch]\n",
+    "\n",
+    "# Separate imgs and masks, cast masks to torch tensors\n",
+    "img = torch.stack([imgMask[0] for imgMask in imgMasks])\n",
+    "mask = torch.stack([torch.from_numpy(imgMask[1]) for \\\n",
+    "                    imgMask in imgMasks])"
    ]
   },
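
Note: the checkpoint hunk above swaps `checkpoint['model']` for `checkpoint['module']`, consistent with DeepSpeed-saved model states (`mp_rank_00_model_states.pt`), while the inline comment keeps `'model'` as a fallback. A tolerant load can cover both layouts; a minimal sketch, assuming `checkpoint` is the dict returned by `torch.load`:

```python
# Try the DeepSpeed-style 'module' key first, then 'model',
# then assume the checkpoint itself is the state dict (sketch only).
state_dict = checkpoint.get('module', checkpoint.get('model', checkpoint))
model.load_state_dict(state_dict)
```
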
@@ -220,7 +238,7 @@
    "id": "55acf5e9-eb2a-496c-baa6-3b74503a2978",
    "metadata": {},
    "source": [
-    "## Prediction helper functions"
+    "## 4. Prediction helper functions"
    ]
   },
   {
@@ -236,30 +254,33 @@
     "    outputs = []\n",
     "    masks = []\n",
     "    losses = []\n",
+    "    with tqdm(total=num_batches) as pbar:\n",
     "\n",
-    "    for idx, img_mask in enumerate(dataloader):\n",
-    "        \n",
-    "        if idx > num_batches:\n",
-    "            return inputs, outputs, masks, losses\n",
+    "        for idx, img_mask in enumerate(dataloader):\n",
+    "            \n",
+    "            pbar.update(1)\n",
     "\n",
-    "        img_mask = img_mask[0]\n",
+    "            if idx > num_batches:\n",
+    "                return inputs, outputs, masks, losses\n",
     "\n",
-    "        img = torch.stack([pair[0] for pair in img_mask])\n",
-    "        mask = torch.stack([pair[1] for pair in img_mask])\n",
+    "            img_mask = img_mask[0]\n",
     "\n",
-    "        img = img.cuda(non_blocking=True)\n",
-    "        mask = mask.cuda(non_blocking=True)\n",
+    "            img = torch.stack([pair[0] for pair in img_mask])\n",
+    "            mask = torch.stack([pair[1] for pair in img_mask])\n",
     "\n",
-    "        with torch.no_grad():\n",
-    "            with amp.autocast(enabled=config.ENABLE_AMP):\n",
-    "                z = model.encoder(img, mask)\n",
-    "                img_recon = model.decoder(z)\n",
-    "                loss = model(img, mask)\n",
+    "            img = img.cuda(non_blocking=True)\n",
+    "            mask = mask.cuda(non_blocking=True)\n",
     "\n",
-    "        inputs.extend(img.cpu())\n",
-    "        masks.extend(mask.cpu())\n",
-    "        outputs.extend(img_recon.cpu())\n",
-    "        losses.append(loss.cpu())\n",
+    "            with torch.no_grad():\n",
+    "                with amp.autocast(enabled=config.ENABLE_AMP):\n",
+    "                    z = model.encoder(img, mask)\n",
+    "                    img_recon = model.decoder(z)\n",
+    "                    loss = model(img, mask)\n",
+    "\n",
+    "            inputs.extend(img.cpu())\n",
+    "            masks.extend(mask.cpu())\n",
+    "            outputs.extend(img_recon.cpu())\n",
+    "            losses.append(loss.cpu())\n",
     "    \n",
     "    return inputs, outputs, masks, losses\n",
     "\n",
@@ -274,11 +295,11 @@
     "\n",
     "\n",
     "def process_mask(mask):\n",
-    "    mask = mask.unsqueeze(0)\n",
-    "    mask = mask.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous()\n",
-    "    mask = mask[0, 0, :, :]\n",
-    "    mask = np.stack([mask, mask, mask], axis=-1)\n",
-    "    return mask\n",
+    "    mask_img = mask.unsqueeze(0)\n",
+    "    mask_img = mask_img.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous()\n",
+    "    mask_img = mask_img[0, 0, :, :]\n",
+    "    mask_img = np.stack([mask_img, mask_img, mask_img], axis=-1)\n",
+    "    return mask_img\n",
     "\n",
     "\n",
     "def process_prediction(image, img_recon, mask, rgb_index):\n",
@@ -310,11 +331,10 @@
     "    return rgb_image, rgb_image_masked, rgb_recon_masked, mask\n",
     "\n",
     "\n",
-    "def plot_export_pdf(path, num_sample, inputs, outputs, masks, rgb_index):\n",
-    "    random_subsample = random.sample(range(len(inputs)), num_sample)\n",
+    "def plot_export_pdf(path, inputs, outputs, masks, rgb_index):\n",
     "    pdf_plot_obj = PdfPages(path)\n",
     "\n",
-    "    for idx in random_subsample:\n",
+    "    for idx in range(len(inputs)):\n",
     "        # prediction processing\n",
     "        image = inputs[idx]\n",
     "        img_recon = outputs[idx]\n",
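
Note: `process_mask` above upsamples the patch-level mask to pixel resolution with `repeat_interleave` (a factor of 4 along each spatial axis) and stacks it into three channels so it can overlay an RGB image. A standalone toy example, with illustrative values only:

```python
import numpy as np
import torch

# Toy 2x2 patch mask: 1 = masked patch, 0 = visible patch
patch_mask = torch.tensor([[0, 1],
                           [1, 0]])

# Expand each patch entry into a 4x4 pixel block,
# as process_mask does along the spatial axes
pixel_mask = patch_mask.repeat_interleave(4, 0).repeat_interleave(4, 1)
print(pixel_mask.shape)  # torch.Size([8, 8])

# Replicate into 3 identical channels for an RGB overlay
rgb_mask = np.stack([pixel_mask.numpy()] * 3, axis=-1)
print(rgb_mask.shape)  # (8, 8, 3)
```
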
Predict" ] }, { "cell_type": "code", "execution_count": null, - "id": "fa43bfaf-6379-43d5-9be3-bd0e55f5ca12", + "id": "4e695cc3-b869-4fc2-b360-b45f3b81affd", "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", - "inputs, outputs, masks, losses = predict(model, dataloader, num_batches=64)" + "inputs = []\n", + "outputs = []\n", + "masks = []\n", + "losses = []\n", + "\n", + "# We could do this in a single batch however we\n", + "# want to report the loss per-image, in place of\n", + "# loss per-batch.\n", + "for i in tqdm(range(img.shape[0])):\n", + " single_img = img[i].unsqueeze(0)\n", + " single_mask = mask[i].unsqueeze(0)\n", + " single_img = single_img.cuda(non_blocking=True)\n", + " single_mask = single_mask.cuda(non_blocking=True)\n", + "\n", + " with torch.no_grad():\n", + " z = model.encoder(single_img, single_mask)\n", + " img_recon = model.decoder(z)\n", + " loss = model(single_img, single_mask)\n", + "\n", + " inputs.extend(single_img.cpu())\n", + " masks.extend(single_mask.cpu())\n", + " outputs.extend(img_recon.cpu())\n", + " losses.append(loss.cpu()) " ] }, { @@ -367,7 +407,9 @@ "id": "dc3f102c-94df-4d9e-8040-52197a7e71db", "metadata": {}, "source": [ - "## Plot and write to PDF" + "## 6. Plot and write to PDF\n", + "\n", + "Writes out all of the predictions to a PDF file" ] }, { @@ -377,11 +419,10 @@ "metadata": {}, "outputs": [], "source": [ - "pdf_path = '../../satvision-toa-reconstruction-pdf-03.15.16patch.huge.001.pdf'\n", - "num_samples = 25 # Number of random samples from the predictions\n", + "pdf_path = '../../satvision-toa-reconstruction-pdf-huge-patch-8-04.30.pdf'\n", "rgb_index = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n", "\n", - "plot_export_pdf(pdf_path, num_samples, inputs, outputs, masks, rgb_index)" + "plot_export_pdf(pdf_path, inputs, outputs, masks, rgb_index)" ] }, {