Develop #54

Merged: 3 commits, May 14, 2024
177 changes: 109 additions & 68 deletions notebooks/satvision-toa-reconstruction.ipynb
@@ -7,7 +7,7 @@
"source": [
"# Satvision-TOA Reconstruction Notebook\n",
"\n",
"Version: 03.15.24\n",
"Version: 04.30.24\n",
"\n",
"Env: `Python [conda env:ilab-pytorch]`"
]
@@ -34,6 +34,7 @@
"import time\n",
"import random\n",
"import datetime\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import logging\n",
"\n",
@@ -75,7 +76,7 @@
"id": "d841e464-f880-4e53-bf31-f9f225713918",
"metadata": {},
"source": [
"## Configuration"
"## 1. Configuration"
]
},
{
@@ -91,20 +92,27 @@
"\n",
"git lfs install\n",
"\n",
"git clone [email protected]:nasa-cisto-data-science-group/satvision-toa-huge\n",
"git clone git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-huge-patch8-window12-192\n",
"```\n",
"\n",
"Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth. \n",
"Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth.\n",
"https://huggingface.co/docs/hub/security-git-ssh\n",
"\n",
"If experiencing ssh-related authentication issues:\n",
"```bash\n",
"eval `ssh-agent -s` # starts ssh-agent\n",
"eval $(ssh-agent)\n",
"\n",
"ssh-add -l # is your ssh key added to the agent?\n",
"# If this outputs as anon, follow the next steps.\n",
"ssh -T [email protected]\n",
"\n",
"ssh-add ~/.ssh/id_xxxx # adds ssh ID to ssh-agent\n",
"# Check if ssh-agent is using the proper key\n",
"ssh-add -l\n",
"\n",
"# If not\n",
"ssh-add ~/.ssh/your-key\n",
"\n",
"# Or if you want to use the default id_* key, just do\n",
"ssh-add\n",
"\n",
"ssh -T [email protected] # Should return \"Hi <user-id>, welcome to Hugging Face.\"\n",
"```"
]
},
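If cloning over SSH keeps failing, the same repository can likely be fetched with the `huggingface_hub` Python client instead of git+lfs; a minimal sketch, assuming the repo id matches the clone URL above and that your Hugging Face credentials can read it:

```python
# Sketch only: pull the checkpoint with huggingface_hub instead of git+lfs.
# Assumes the repo id matches the clone URL above and is readable with your credentials.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="nasa-cisto-data-science-group/satvision-toa-huge-patch8-window12-192",
    local_dir="../../satvision-toa-huge-patch8-window12-192",
)
print("downloaded to:", local_dir)
```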
@@ -115,13 +123,12 @@
"metadata": {},
"outputs": [],
"source": [
"MODEL_PATH: str = '../../satvision-toa-huge/ckpt_epoch_100.pth'\n",
"CONFIG_PATH: str = '../../satvision-toa-huge/mim_pretrain_swinv2_satvision_huge_192_window12_200ep.yaml'\n",
"MODEL_PATH: str = '../../satvision-toa-huge-patch8-window12-192/mp_rank_00_model_states.pt'\n",
"CONFIG_PATH: str = '../../satvision-toa-huge-patch8-window12-192/mim_pretrain_swinv2_satvision_huge_192_window12_100ep.yaml'\n",
"\n",
"BATCH_SIZE: int = 1 # Want to report loss on every image? Change to 1.\n",
"OUTPUT: str = '.'\n",
"TAG: str = 'satvision-base-toa-reconstruction'\n",
"DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets'\n",
"TAG: str = 'satvision-huge-toa-reconstruction'\n",
"DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/validation/sv_toa_128_chip_validation_04_24.npy'\n",
"DATA_PATHS: list = [DATA_PATH]"
]
},
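The paths above are relative to wherever the notebook is launched, so a tiny optional check (not part of the PR) can fail fast if they do not resolve:

```python
# Optional sketch: fail fast if the relative paths above do not resolve.
from pathlib import Path

for p in (MODEL_PATH, CONFIG_PATH, DATA_PATH):
    assert Path(p).exists(), f"path not found: {p}"
```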
@@ -140,7 +147,6 @@
"config.defrost()\n",
"config.MODEL.RESUME = MODEL_PATH\n",
"config.DATA.DATA_PATHS = DATA_PATHS\n",
"config.DATA.BATCH_SIZE = BATCH_SIZE\n",
"config.OUTPUT = OUTPUT\n",
"config.TAG = TAG\n",
"config.freeze()"
@@ -169,6 +175,14 @@
"logger.addHandler(console)"
]
},
{
"cell_type": "markdown",
"id": "11ebd497-7741-41a7-af9d-0ee49a6313a4",
"metadata": {},
"source": [
"## 2. Load model weights from checkpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -178,7 +192,7 @@
"source": [
"checkpoint = torch.load(MODEL_PATH)\n",
"model = build_model(config, pretrain=True)\n",
"model.load_state_dict(checkpoint['model'])\n",
"model.load_state_dict(checkpoint['module']) # If 'module' not working, try 'model'\n",
"n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"logger.info(f\"number of params: {n_parameters}\")\n",
"model.cuda()\n",
@@ -187,40 +201,44 @@
},
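The inline comment in this cell notes that the checkpoint's top-level key may be 'module' (typical of DeepSpeed-saved states such as mp_rank_00_model_states.pt) or 'model'. A minimal sketch of that fallback using only standard PyTorch calls; the candidate key list and strict=False are assumptions, not something the PR defines:

```python
import torch

# Sketch of the fallback the comment describes: pick whichever top-level key
# the checkpoint actually carries. 'module' is typical for DeepSpeed engine
# checkpoints, 'model' for plain PyTorch ones; the candidate list is an assumption.
checkpoint = torch.load(MODEL_PATH, map_location="cpu")
state_key = next(key for key in ("module", "model", "state_dict") if key in checkpoint)
missing, unexpected = model.load_state_dict(checkpoint[state_key], strict=False)
print(f"loaded '{state_key}': {len(missing)} missing / {len(unexpected)} unexpected keys")
```

`load_state_dict` returns the lists of missing and unexpected keys, which is a cheap way to confirm the encoder weights actually landed.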
{
"cell_type": "markdown",
"id": "bd9ba52e-62ca-4800-b2aa-deaaea64be9f",
"id": "b500d13b-89d7-4cd8-a36a-ab6f10f6a397",
"metadata": {},
"source": [
"## Dataloader"
"## 3. Load evaluation set (from numpy file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75c52c66-f322-413c-be76-6c7abfd159bc",
"id": "73a8d307-de9b-4617-abdd-dae1e7c2521a",
"metadata": {},
"outputs": [],
"source": [
"dataset = MODIS22MDataset(config,\n",
" config.DATA.DATA_PATHS,\n",
" split=\"train\",\n",
" img_size=config.DATA.IMG_SIZE,\n",
" transform=SimmimTransform(config),\n",
" batch_size=config.DATA.BATCH_SIZE).dataset()\n",
"\n",
"dataloader = torch.utils.data.DataLoader(\n",
" dataset,\n",
" batch_size=None, # Change if not using webdataset as underlying dataset type\n",
" num_workers=15,\n",
" shuffle=False,\n",
" pin_memory=True,)"
"# Use the Masked-Image-Modeling transform\n",
"transform = SimmimTransform(config)\n",
"\n",
"# The reconstruction evaluation set is a single numpy file\n",
"validation_dataset_path = config.DATA.DATA_PATHS[0]\n",
"validation_dataset = np.load(validation_dataset_path)\n",
"len_batch = range(validation_dataset.shape[0])\n",
"\n",
"# Apply transform to each image in the batch\n",
"# A mask is auto-generated in the transform\n",
"imgMasks = [transform(validation_dataset[idx]) for idx \\\n",
" in len_batch]\n",
"\n",
"# Seperate img and masks, cast masks to torch tensor\n",
"img = torch.stack([imgMask[0] for imgMask in imgMasks])\n",
"mask = torch.stack([torch.from_numpy(imgMask[1]) for \\\n",
" imgMask in imgMasks])"
]
},
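Before inference, a quick shape check can confirm the transform produced matching image and mask stacks; this is an optional sketch, and the shapes in the comments are expectations rather than values stated in the PR:

```python
# Optional sketch: confirm the transformed batch lines up before inference.
print("validation chips:", validation_dataset.shape)  # e.g. (N, 128, 128, bands)
print("img tensor:", tuple(img.shape))                # likely (N, bands, IMG_SIZE, IMG_SIZE)
print("mask tensor:", tuple(mask.shape))              # patch-level mask per image
assert img.shape[0] == mask.shape[0] == validation_dataset.shape[0]
```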
{
"cell_type": "markdown",
"id": "55acf5e9-eb2a-496c-baa6-3b74503a2978",
"metadata": {},
"source": [
"## Prediction helper functions"
"## 4. Prediction helper functions"
]
},
{
@@ -236,30 +254,33 @@
" outputs = []\n",
" masks = []\n",
" losses = []\n",
" with tqdm(total=num_batches) as pbar:\n",
"\n",
" for idx, img_mask in enumerate(dataloader):\n",
" \n",
" if idx > num_batches:\n",
" return inputs, outputs, masks, losses\n",
" for idx, img_mask in enumerate(dataloader):\n",
" \n",
" pbar.update(1)\n",
"\n",
" img_mask = img_mask[0]\n",
" if idx > num_batches:\n",
" return inputs, outputs, masks, losses\n",
"\n",
" img = torch.stack([pair[0] for pair in img_mask])\n",
" mask = torch.stack([pair[1] for pair in img_mask])\n",
" img_mask = img_mask[0]\n",
"\n",
" img = img.cuda(non_blocking=True)\n",
" mask = mask.cuda(non_blocking=True)\n",
" img = torch.stack([pair[0] for pair in img_mask])\n",
" mask = torch.stack([pair[1] for pair in img_mask])\n",
"\n",
" with torch.no_grad():\n",
" with amp.autocast(enabled=config.ENABLE_AMP):\n",
" z = model.encoder(img, mask)\n",
" img_recon = model.decoder(z)\n",
" loss = model(img, mask)\n",
" img = img.cuda(non_blocking=True)\n",
" mask = mask.cuda(non_blocking=True)\n",
"\n",
" inputs.extend(img.cpu())\n",
" masks.extend(mask.cpu())\n",
" outputs.extend(img_recon.cpu())\n",
" losses.append(loss.cpu())\n",
" with torch.no_grad():\n",
" with amp.autocast(enabled=config.ENABLE_AMP):\n",
" z = model.encoder(img, mask)\n",
" img_recon = model.decoder(z)\n",
" loss = model(img, mask)\n",
"\n",
" inputs.extend(img.cpu())\n",
" masks.extend(mask.cpu())\n",
" outputs.extend(img_recon.cpu())\n",
" losses.append(loss.cpu())\n",
" \n",
" return inputs, outputs, masks, losses\n",
"\n",
@@ -274,11 +295,11 @@
"\n",
"\n",
"def process_mask(mask):\n",
" mask = mask.unsqueeze(0)\n",
" mask = mask.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous()\n",
" mask = mask[0, 0, :, :]\n",
" mask = np.stack([mask, mask, mask], axis=-1)\n",
" return mask\n",
" mask_img = mask.unsqueeze(0)\n",
" mask_img = mask_img.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous()\n",
" mask_img = mask_img[0, 0, :, :]\n",
" mask_img = np.stack([mask_img, mask_img, mask_img], axis=-1)\n",
" return mask_img\n",
"\n",
"\n",
"def process_prediction(image, img_recon, mask, rgb_index):\n",
Expand Down Expand Up @@ -310,11 +331,10 @@
" return rgb_image, rgb_image_masked, rgb_recon_masked, mask\n",
"\n",
"\n",
"def plot_export_pdf(path, num_sample, inputs, outputs, masks, rgb_index):\n",
" random_subsample = random.sample(range(len(inputs)), num_sample)\n",
"def plot_export_pdf(path, inputs, outputs, masks, rgb_index):\n",
" pdf_plot_obj = PdfPages(path)\n",
"\n",
" for idx in random_subsample:\n",
" for idx in range(len(inputs)):\n",
" # prediction processing\n",
" image = inputs[idx]\n",
" img_recon = outputs[idx]\n",
@@ -347,27 +367,49 @@
"id": "551c44b5-6d88-45c4-b397-c38de8064544",
"metadata": {},
"source": [
"## Predict"
"## 5. Predict"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa43bfaf-6379-43d5-9be3-bd0e55f5ca12",
"id": "4e695cc3-b869-4fc2-b360-b45f3b81affd",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"inputs, outputs, masks, losses = predict(model, dataloader, num_batches=64)"
"inputs = []\n",
"outputs = []\n",
"masks = []\n",
"losses = []\n",
"\n",
"# We could do this in a single batch however we\n",
"# want to report the loss per-image, in place of\n",
"# loss per-batch.\n",
"for i in tqdm(range(img.shape[0])):\n",
" single_img = img[i].unsqueeze(0)\n",
" single_mask = mask[i].unsqueeze(0)\n",
" single_img = single_img.cuda(non_blocking=True)\n",
" single_mask = single_mask.cuda(non_blocking=True)\n",
"\n",
" with torch.no_grad():\n",
" z = model.encoder(single_img, single_mask)\n",
" img_recon = model.decoder(z)\n",
" loss = model(single_img, single_mask)\n",
"\n",
" inputs.extend(single_img.cpu())\n",
" masks.extend(single_mask.cpu())\n",
" outputs.extend(img_recon.cpu())\n",
" losses.append(loss.cpu()) "
]
},
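Since this loop records one loss per image, a short optional sketch to summarize them before plotting:

```python
# Sketch: quick summary of the per-image reconstruction losses gathered above.
loss_values = np.array([loss_tensor.item() for loss_tensor in losses])
print(f"images evaluated: {loss_values.size}")
print(f"loss mean: {loss_values.mean():.4f}  min: {loss_values.min():.4f}  max: {loss_values.max():.4f}")
```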
{
"cell_type": "markdown",
"id": "dc3f102c-94df-4d9e-8040-52197a7e71db",
"metadata": {},
"source": [
"## Plot and write to PDF"
"## 6. Plot and write to PDF\n",
"\n",
"Writes out all of the predictions to a PDF file"
]
},
{
@@ -377,11 +419,10 @@
"metadata": {},
"outputs": [],
"source": [
"pdf_path = '../../satvision-toa-reconstruction-pdf-03.15.16patch.huge.001.pdf'\n",
"num_samples = 25 # Number of random samples from the predictions\n",
"pdf_path = '../../satvision-toa-reconstruction-pdf-huge-patch-8-04.30.pdf'\n",
"rgb_index = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n",
"\n",
"plot_export_pdf(pdf_path, num_samples, inputs, outputs, masks, rgb_index)"
"plot_export_pdf(pdf_path, inputs, outputs, masks, rgb_index)"
]
},
{