Skip to content

Commit

Permalink
updated notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
cssprad1 committed Mar 15, 2024
1 parent f6a5741 commit 4e822b2
Showing 1 changed file with 44 additions and 21 deletions.
65 changes: 44 additions & 21 deletions notebooks/satvision-toa-reconstruction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"source": [
"# Satvision-TOA Reconstruction Notebook\n",
"\n",
"Version: 02.20.24\n",
"Version: 03.15.24\n",
"\n",
"Env: `Python [conda env:ilab-pytorch]`"
]
Expand Down Expand Up @@ -59,13 +59,11 @@
"\n",
"from pytorch_caney.config import get_config\n",
"\n",
"from pytorch_caney.training.mim_utils import load_checkpoint, load_pretrained\n",
"\n",
"from pytorch_caney.models.build import build_model\n",
"\n",
"from pytorch_caney.ptc_logging import create_logger\n",
"\n",
"from pytorch_caney.data.datamodules import mim_webdataset_datamodule\n",
"from pytorch_caney.data.datasets.mim_modis_22m_dataset import MODIS22MDataset\n",
"\n",
"from pytorch_caney.data.transforms import SimmimTransform, SimmimMaskGenerator\n",
"\n",
Expand Down Expand Up @@ -93,10 +91,21 @@
"\n",
"git lfs install\n",
"\n",
"git clone [email protected]:nasa-cisto-data-science-group/satvision-toa-base\n",
"git clone [email protected]:nasa-cisto-data-science-group/satvision-toa-huge\n",
"```\n",
"\n",
"Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth. "
"Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth. \n",
"\n",
"If experiencing ssh-related authentication issues:\n",
"```bash\n",
"eval `ssh-agent -s` # starts ssh-agent\n",
"\n",
"ssh-add -l # is your ssh key added to the agent?\n",
"\n",
"ssh-add ~/.ssh/id_xxxx # adds ssh ID to ssh-agent\n",
"\n",
"ssh -T [email protected] # Should return \"Hi <user-id>, welcome to Hugging Face.\"\n",
"```"
]
},
{
Expand All @@ -106,10 +115,10 @@
"metadata": {},
"outputs": [],
"source": [
"MODEL_PATH: str = '../../satvision-toa-base/satvision-toa_84M_2M_100.pth'\n",
"CONFIG_PATH: str = '../../satvision-toa-base/mim_pretrain_swinv2_satvision-toa_base_192_window12_800ep.yaml'\n",
"MODEL_PATH: str = '../../satvision-toa-huge/ckpt_epoch_100.pth'\n",
"CONFIG_PATH: str = '../../satvision-toa-huge/mim_pretrain_swinv2_satvision_huge_192_window12_200ep.yaml'\n",
"\n",
"BATCH_SIZE: int = 64 # Want to report loss on every image? Change to 1.\n",
"BATCH_SIZE: int = 1 # Want to report loss on every image? Change to 1.\n",
"OUTPUT: str = '.'\n",
"TAG: str = 'satvision-base-toa-reconstruction'\n",
"DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets'\n",
Expand Down Expand Up @@ -191,7 +200,19 @@
"metadata": {},
"outputs": [],
"source": [
"dataloader = mim_webdataset_datamodule.build_mim_dataloader(config, logger)"
"dataset = MODIS22MDataset(config,\n",
" config.DATA.DATA_PATHS,\n",
" split=\"train\",\n",
" img_size=config.DATA.IMG_SIZE,\n",
" transform=SimmimTransform(config),\n",
" batch_size=config.DATA.BATCH_SIZE).dataset()\n",
"\n",
"dataloader = torch.utils.data.DataLoader(\n",
" dataset,\n",
" batch_size=None, # Change if not using webdataset as underlying dataset type\n",
" num_workers=15,\n",
" shuffle=False,\n",
" pin_memory=True,)"
]
},
{
Expand Down Expand Up @@ -238,7 +259,7 @@
" inputs.extend(img.cpu())\n",
" masks.extend(mask.cpu())\n",
" outputs.extend(img_recon.cpu())\n",
" losses.append(losses)\n",
" losses.append(loss.cpu())\n",
" \n",
" return inputs, outputs, masks, losses\n",
"\n",
Expand All @@ -261,24 +282,26 @@
"\n",
"\n",
"def process_prediction(image, img_recon, mask, rgb_index):\n",
" img_normed = minmax_norm(image.numpy())\n",
"\n",
" mask = process_mask(mask)\n",
" \n",
" red_idx = rgb_index[0]\n",
" blue_idx = rgb_index[1]\n",
" green_idx = rgb_index[2]\n",
"\n",
" rgb_image = np.stack((img_normed[red_idx, :, :],\n",
" img_normed[blue_idx, :, :],\n",
" img_normed[green_idx, :, :]),\n",
" axis=-1)\n",
" image = image.numpy()\n",
" rgb_image = np.stack((image[red_idx, :, :],\n",
" image[blue_idx, :, :],\n",
" image[green_idx, :, :]),\n",
" axis=-1)\n",
" rgb_image = minmax_norm(rgb_image)\n",
"\n",
" img_recon = minmax_norm(img_recon.numpy())\n",
" img_recon = img_recon.numpy()\n",
" rgb_image_recon = np.stack((img_recon[red_idx, :, :],\n",
" img_recon[blue_idx, :, :],\n",
" img_recon[green_idx, :, :]),\n",
" axis=-1)\n",
" rgb_image_recon = minmax_norm(rgb_image_recon)\n",
"\n",
" rgb_masked = np.where(mask == 0, rgb_image, rgb_image_recon)\n",
" rgb_image_masked = np.where(mask == 1, 0, rgb_image)\n",
Expand Down Expand Up @@ -336,7 +359,7 @@
"source": [
"%%time\n",
"\n",
"inputs, outputs, masks, losses = predict(model, dataloader, num_batches=5)"
"inputs, outputs, masks, losses = predict(model, dataloader, num_batches=64)"
]
},
{
Expand All @@ -354,9 +377,9 @@
"metadata": {},
"outputs": [],
"source": [
"pdf_path = '../../satvision-toa-reconstruction-pdf-02.20.pdf'\n",
"num_samples = 10 # Number of random samples from the predictions\n",
"rgb_index = [0, 3, 2] # Indices of [Red band, Blue band, Green band]\n",
"pdf_path = '../../satvision-toa-reconstruction-pdf-03.15.16patch.huge.001.pdf'\n",
"num_samples = 25 # Number of random samples from the predictions\n",
"rgb_index = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n",
"\n",
"plot_export_pdf(pdf_path, num_samples, inputs, outputs, masks, rgb_index)"
]
Expand Down

0 comments on commit 4e822b2

Please sign in to comment.