diff --git a/notebooks/satvision-toa-reconstruction.ipynb b/notebooks/satvision-toa-reconstruction.ipynb index b4b8f1e..2f98f75 100644 --- a/notebooks/satvision-toa-reconstruction.ipynb +++ b/notebooks/satvision-toa-reconstruction.ipynb @@ -7,7 +7,7 @@ "source": [ "# Satvision-TOA Reconstruction Notebook\n", "\n", - "Version: 02.20.24\n", + "Version: 03.15.24\n", "\n", "Env: `Python [conda env:ilab-pytorch]`" ] @@ -59,13 +59,11 @@ "\n", "from pytorch_caney.config import get_config\n", "\n", - "from pytorch_caney.training.mim_utils import load_checkpoint, load_pretrained\n", - "\n", "from pytorch_caney.models.build import build_model\n", "\n", "from pytorch_caney.ptc_logging import create_logger\n", "\n", - "from pytorch_caney.data.datamodules import mim_webdataset_datamodule\n", + "from pytorch_caney.data.datasets.mim_modis_22m_dataset import MODIS22MDataset\n", "\n", "from pytorch_caney.data.transforms import SimmimTransform, SimmimMaskGenerator\n", "\n", @@ -93,10 +91,21 @@ "\n", "git lfs install\n", "\n", - "git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-base\n", + "git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-huge\n", "```\n", "\n", - "Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth. " + "Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth. \n", + "\n", + "If experiencing ssh-related authentication issues:\n", + "```bash\n", + "eval `ssh-agent -s` # starts ssh-agent\n", + "\n", + "ssh-add -l # is your ssh key added to the agent?\n", + "\n", + "ssh-add ~/.ssh/id_xxxx # adds ssh ID to ssh-agent\n", + "\n", + "ssh -T git@hf.co # Should return \"Hi , welcome to Hugging Face.\"\n", + "```" ] }, { @@ -106,10 +115,10 @@ "metadata": {}, "outputs": [], "source": [ - "MODEL_PATH: str = '../../satvision-toa-base/satvision-toa_84M_2M_100.pth'\n", - "CONFIG_PATH: str = '../../satvision-toa-base/mim_pretrain_swinv2_satvision-toa_base_192_window12_800ep.yaml'\n", + "MODEL_PATH: str = '../../satvision-toa-huge/ckpt_epoch_100.pth'\n", + "CONFIG_PATH: str = '../../satvision-toa-huge/mim_pretrain_swinv2_satvision_huge_192_window12_200ep.yaml'\n", "\n", - "BATCH_SIZE: int = 64 # Want to report loss on every image? Change to 1.\n", + "BATCH_SIZE: int = 1 # Want to report loss on every image? Change to 1.\n", "OUTPUT: str = '.'\n", "TAG: str = 'satvision-base-toa-reconstruction'\n", "DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets'\n", @@ -191,7 +200,19 @@ "metadata": {}, "outputs": [], "source": [ - "dataloader = mim_webdataset_datamodule.build_mim_dataloader(config, logger)" + "dataset = MODIS22MDataset(config,\n", + " config.DATA.DATA_PATHS,\n", + " split=\"train\",\n", + " img_size=config.DATA.IMG_SIZE,\n", + " transform=SimmimTransform(config),\n", + " batch_size=config.DATA.BATCH_SIZE).dataset()\n", + "\n", + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=None, # Change if not using webdataset as underlying dataset type\n", + " num_workers=15,\n", + " shuffle=False,\n", + " pin_memory=True,)" ] }, { @@ -238,7 +259,7 @@ " inputs.extend(img.cpu())\n", " masks.extend(mask.cpu())\n", " outputs.extend(img_recon.cpu())\n", - " losses.append(losses)\n", + " losses.append(loss.cpu())\n", " \n", " return inputs, outputs, masks, losses\n", "\n", @@ -261,7 +282,6 @@ "\n", "\n", "def process_prediction(image, img_recon, mask, rgb_index):\n", - " img_normed = minmax_norm(image.numpy())\n", "\n", " mask = process_mask(mask)\n", " \n", @@ -269,16 +289,19 @@ " blue_idx = rgb_index[1]\n", " green_idx = rgb_index[2]\n", "\n", - " rgb_image = np.stack((img_normed[red_idx, :, :],\n", - " img_normed[blue_idx, :, :],\n", - " img_normed[green_idx, :, :]),\n", - " axis=-1)\n", + " image = image.numpy()\n", + " rgb_image = np.stack((image[red_idx, :, :],\n", + " image[blue_idx, :, :],\n", + " image[green_idx, :, :]),\n", + " axis=-1)\n", + " rgb_image = minmax_norm(rgb_image)\n", "\n", - " img_recon = minmax_norm(img_recon.numpy())\n", + " img_recon = img_recon.numpy()\n", " rgb_image_recon = np.stack((img_recon[red_idx, :, :],\n", " img_recon[blue_idx, :, :],\n", " img_recon[green_idx, :, :]),\n", " axis=-1)\n", + " rgb_image_recon = minmax_norm(rgb_image_recon)\n", "\n", " rgb_masked = np.where(mask == 0, rgb_image, rgb_image_recon)\n", " rgb_image_masked = np.where(mask == 1, 0, rgb_image)\n", @@ -336,7 +359,7 @@ "source": [ "%%time\n", "\n", - "inputs, outputs, masks, losses = predict(model, dataloader, num_batches=5)" + "inputs, outputs, masks, losses = predict(model, dataloader, num_batches=64)" ] }, { @@ -354,9 +377,9 @@ "metadata": {}, "outputs": [], "source": [ - "pdf_path = '../../satvision-toa-reconstruction-pdf-02.20.pdf'\n", - "num_samples = 10 # Number of random samples from the predictions\n", - "rgb_index = [0, 3, 2] # Indices of [Red band, Blue band, Green band]\n", + "pdf_path = '../../satvision-toa-reconstruction-pdf-03.15.16patch.huge.001.pdf'\n", + "num_samples = 25 # Number of random samples from the predictions\n", + "rgb_index = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n", "\n", "plot_export_pdf(pdf_path, num_samples, inputs, outputs, masks, rgb_index)" ]