Skip to content

Commit

Permalink
Add option of caching image in bytes rather than float32
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye committed Jan 18, 2024
1 parent 15e81d3 commit b86b45c
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 31 deletions.
6 changes: 4 additions & 2 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class FullImageDatamanagerConfig(DataManagerConfig):
"""Specifies the image indices to use during eval; if None, uses all."""
cache_images: Literal["cpu", "gpu"] = "cpu"
"""Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device."""
cache_images_type: Literal["uint8", "float32"] = "float32"
"""The image type returned from manager, caching images in uint8 saves memory"""


class FullImageDatamanager(DataManager, Generic[TDataset]):
Expand Down Expand Up @@ -126,7 +128,7 @@ def cache_images(self, cache_images_option):
CONSOLE.log("Caching / undistorting train images")
for i in tqdm(range(len(self.train_dataset)), leave=False):
# cv2.undistort the images / cameras
data = self.train_dataset.get_data(i)
data = self.train_dataset.get_data(i, image_type=self.config.cache_images_type)
camera = self.train_dataset.cameras[i].reshape(())
K = camera.get_intrinsics_matrices().numpy()
if camera.distortion_params is None:
Expand Down Expand Up @@ -201,7 +203,7 @@ def cache_images(self, cache_images_option):
CONSOLE.log("Caching / undistorting eval images")
for i in tqdm(range(len(self.eval_dataset)), leave=False):
# cv2.undistort the images / cameras
data = self.eval_dataset.get_data(i)
data = self.eval_dataset.get_data(i, image_type=self.config.cache_images_type)
camera = self.eval_dataset.cameras[i].reshape(())
K = camera.get_intrinsics_matrices().numpy()
if camera.distortion_params is None:
Expand Down
39 changes: 33 additions & 6 deletions nerfstudio/data/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

from copy import deepcopy
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Literal

import numpy as np
import numpy.typing as npt
import torch
from jaxtyping import Float
from jaxtyping import Float, UInt8
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
Expand Down Expand Up @@ -77,24 +77,51 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]:
assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is in correct."
return image

def get_image(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]:
"""Returns a 3 channel image.
def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]:
"""Returns a 3 channel image in float32 torch.Tensor.
Args:
image_idx: The image index in the dataset.
"""
image = torch.from_numpy(self.get_numpy_image(image_idx).astype("float32") / 255.0)
if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4:
assert (self._dataparser_outputs.alpha_color >= 0).all() and (
self._dataparser_outputs.alpha_color <= 1
).all(), "alpha color given is out of range between [0, 1]."
image = image[:, :, :3] * image[:, :, -1:] + self._dataparser_outputs.alpha_color * (1.0 - image[:, :, -1:])
return image

def get_data(self, image_idx: int) -> Dict:
def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_width num_channels"]:
"""Returns a 3 channel image in uint8 torch.Tensor.
Args:
image_idx: The image index in the dataset.
"""
image = torch.from_numpy(self.get_numpy_image(image_idx))
if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4:
assert (self._dataparser_outputs.alpha_color >= 0).all() and (
self._dataparser_outputs.alpha_color <= 1
).all(), "alpha color given is out of range between [0, 1]."
image = image[:, :, :3] * image[:, :, -1:] / 255.0 + 255.0 * self._dataparser_outputs.alpha_color * (
1.0 - image[:, :, -1:] / 255.0
)
image = torch.clamp(image, min=0, max=255).to(torch.uint8)
return image

def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "float32") -> Dict:
"""Returns the ImageDataset data as a dictionary.
Args:
image_idx: The image index in the dataset.
image_type: the type of images returned
"""
image = self.get_image(image_idx)
if image_type == "float32":
image = self.get_image_float32(image_idx)
elif image_type == "uint8":
image = self.get_image_uint8(image_idx)
else:
raise NotImplementedError(f"image_type (={image_type}) getter was not implemented, use uint8 or float32")

data = {"image_idx": image_idx, "image": image}
if self._dataparser_outputs.mask_filenames is not None:
mask_filepath = self._dataparser_outputs.mask_filenames[image_idx]
Expand Down
42 changes: 20 additions & 22 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,25 +724,35 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

return {"rgb": rgb, "depth": depth_im} # type: ignore

def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
"""Compute and returns metrics.
def get_gt_img(self, image: torch.Tensor):
"""Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
Args:
outputs: the output to compute loss dict to
batch: ground truth batch corresponding to outputs
image: tensor.Tensor in type uint8 or float32
"""
if image.dtype == torch.uint8:
image = image.float() / 255.0
d = self._get_downscale_factor()
if d > 1:
newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
newsize = [image.shape[0] // d, image.shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
gt_img = TF.resize(image.permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
gt_img = image
return gt_img.to(self.device)

def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
"""Compute and returns metrics.
Args:
outputs: the output to compute loss dict to
batch: ground truth batch corresponding to outputs
"""
gt_rgb = self.get_gt_img(batch["image"])
metrics_dict = {}
gt_rgb = gt_img.to(self.device) # RGB or RGBA image
predicted_rgb = outputs["rgb"]
metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)

Expand All @@ -758,16 +768,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
batch: ground truth batch corresponding to outputs
metrics_dict: dictionary of metrics, some of which we can use for loss
"""
d = self._get_downscale_factor()
if d > 1:
newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
gt_img = self.get_gt_img(batch["image"])
Ll1 = torch.abs(gt_img - outputs["rgb"]).mean()
simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], outputs["rgb"].permute(2, 0, 1)[None, ...])
if self.config.use_scale_regularization and self.step % 10 == 0:
Expand Down Expand Up @@ -814,20 +815,17 @@ def get_image_metrics_and_images(
Returns:
A dictionary of metrics.
"""
gt_rgb = self.get_gt_img(batch["image"])
d = self._get_downscale_factor()
if d > 1:
# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
predicted_rgb = outputs["rgb"]

gt_rgb = gt_img.to(self.device)

combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)

# Switch images from [H, W, C] to [1, C, H, W] for metrics computations
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/scripts/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _render_trajectory_video(
if render_nearest_camera:
assert train_dataset is not None
assert train_cameras is not None
img = train_dataset.get_image(max_idx)
img = train_dataset.get_image_float32(max_idx)
height = cameras.image_height[0]
# maintain the resolution of the img to calculate the width from the height
width = int(img.shape[1] * (height / img.shape[0]))
Expand Down

0 comments on commit b86b45c

Please sign in to comment.