Encoder + MLP combo (#2063)
* adding encoder + mlp combo

* update documentation

* minor fixes

* fixed issue when tcnn isn't installed

---------

Co-authored-by: Brent Yi <[email protected]>
ethanweber and brentyi authored Dec 1, 2023
1 parent 0cb4100 commit 4c627ed
Showing 3 changed files with 215 additions and 62 deletions.
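The headline addition is MLPWithHashEncoding, a field component that lets tinycudann fuse the hash-grid encoding and the MLP into a single tcnn.NetworkWithInputEncoding while keeping a pure-torch fallback. A minimal usage sketch, not taken from the diff — the argument values are illustrative and a CUDA device is assumed when the tcnn backend is used:

```python
import torch

from nerfstudio.field_components.mlp import MLPWithHashEncoding

# Hypothetical usage sketch. With implementation="tcnn" (and tinycudann installed)
# the hash grid and MLP are fused into one tcnn.NetworkWithInputEncoding; if tcnn
# is missing, a speed warning is printed and the torch fallback is built instead.
mlp = MLPWithHashEncoding(
    num_levels=16,
    min_res=16,
    max_res=2048,
    log2_hashmap_size=19,
    features_per_level=2,
    num_layers=2,
    layer_width=64,
    out_dim=16,
    implementation="tcnn",
).to("cuda")  # tcnn kernels expect CUDA tensors

positions = torch.rand(4096, 3, device="cuda")  # (..., 3) positions normalized to [0, 1]
features = mlp(positions)                       # (..., 16)
```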
88 changes: 62 additions & 26 deletions nerfstudio/field_components/encodings.py
@@ -48,6 +48,11 @@ def __init__(self, in_dim: int) -> None:
raise ValueError("Input dimension should be greater than zero")
super().__init__(in_dim=in_dim)

@classmethod
def get_tcnn_encoding_config(cls) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
raise NotImplementedError("Encoding does not have a TCNN implementation")

@abstractmethod
def forward(self, in_tensor: Shaped[Tensor, "*bs input_dim"]) -> Shaped[Tensor, "*bs output_dim"]:
"""Call forward and returns and processed tensor
@@ -126,14 +131,20 @@ def __init__(
if implementation == "tcnn" and not TCNN_EXISTS:
print_tcnn_speed_warning("NeRFEncoding")
elif implementation == "tcnn":
encoding_config = {"otype": "Frequency", "n_frequencies": num_frequencies}
assert min_freq_exp == 0, "tcnn only supports min_freq_exp = 0"
assert max_freq_exp == num_frequencies - 1, "tcnn only supports max_freq_exp = num_frequencies - 1"
encoding_config = self.get_tcnn_encoding_config(num_frequencies=self.num_frequencies)
self.tcnn_encoding = tcnn.Encoding(
n_input_dims=in_dim,
encoding_config=encoding_config,
)

@classmethod
def get_tcnn_encoding_config(cls, num_frequencies) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {"otype": "Frequency", "n_frequencies": num_frequencies}
return encoding_config

def get_out_dim(self) -> int:
if self.in_dim is None:
raise ValueError("Input dimension has not been set")
@@ -327,48 +338,67 @@ def __init__(
) -> None:
super().__init__(in_dim=3)
self.num_levels = num_levels
self.min_res = min_res
self.features_per_level = features_per_level
self.hash_init_scale = hash_init_scale
self.log2_hashmap_size = log2_hashmap_size
self.hash_table_size = 2**log2_hashmap_size

levels = torch.arange(num_levels)
growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1)) if num_levels > 1 else 1
self.scalings = torch.floor(min_res * growth_factor**levels)
self.growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1)) if num_levels > 1 else 1
self.scalings = torch.floor(min_res * self.growth_factor**levels)

self.hash_offset = levels * self.hash_table_size

self.tcnn_encoding = None
self.hash_table = torch.empty(0)
if implementation == "tcnn" and not TCNN_EXISTS:
if implementation == "torch":
self.build_nn_modules()
elif implementation == "tcnn" and not TCNN_EXISTS:
print_tcnn_speed_warning("HashEncoding")
implementation = "torch"

if implementation == "tcnn":
encoding_config = {
"otype": "HashGrid",
"n_levels": self.num_levels,
"n_features_per_level": self.features_per_level,
"log2_hashmap_size": self.log2_hashmap_size,
"base_resolution": min_res,
"per_level_scale": growth_factor,
}
if interpolation is not None:
encoding_config["interpolation"] = interpolation

self.build_nn_modules()
elif implementation == "tcnn":
encoding_config = self.get_tcnn_encoding_config(
num_levels=self.num_levels,
features_per_level=self.features_per_level,
log2_hashmap_size=self.log2_hashmap_size,
min_res=self.min_res,
growth_factor=self.growth_factor,
interpolation=interpolation,
)
self.tcnn_encoding = tcnn.Encoding(
n_input_dims=3,
encoding_config=encoding_config,
)
elif implementation == "torch":
self.hash_table = torch.rand(size=(self.hash_table_size * num_levels, features_per_level)) * 2 - 1
self.hash_table *= hash_init_scale
self.hash_table = nn.Parameter(self.hash_table)

if self.tcnn_encoding is None:
assert (
interpolation is None or interpolation == "Linear"
), f"interpolation '{interpolation}' is not supported for torch encoding backend"

def build_nn_modules(self) -> None:
"""Initialize the torch version of the hash encoding."""
self.hash_table = torch.rand(size=(self.hash_table_size * self.num_levels, self.features_per_level)) * 2 - 1
self.hash_table *= self.hash_init_scale
self.hash_table = nn.Parameter(self.hash_table)

@classmethod
def get_tcnn_encoding_config(
cls, num_levels, features_per_level, log2_hashmap_size, min_res, growth_factor, interpolation=None
) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {
"otype": "HashGrid",
"n_levels": num_levels,
"n_features_per_level": features_per_level,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": min_res,
"per_level_scale": growth_factor,
}
if interpolation is not None:
encoding_config["interpolation"] = interpolation
return encoding_config

def get_out_dim(self) -> int:
return self.num_levels * self.features_per_level

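The growth factor and the resulting HashGrid config can likewise be computed without constructing the encoding. An illustrative sketch with typical values, not part of the diff:

```python
import numpy as np

from nerfstudio.field_components.encodings import HashEncoding

# Hypothetical example: per_level_scale grows geometrically from min_res to max_res.
num_levels, min_res, max_res = 16, 16, 1024
growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1))  # ~1.32

config = HashEncoding.get_tcnn_encoding_config(
    num_levels=num_levels,
    features_per_level=2,
    log2_hashmap_size=19,
    min_res=min_res,
    growth_factor=growth_factor,
)
assert config["otype"] == "HashGrid"
# With tinycudann installed: tcnn.Encoding(n_input_dims=3, encoding_config=config)
```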
@@ -745,15 +775,21 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "
if implementation == "tcnn" and not TCNN_EXISTS:
print_tcnn_speed_warning("SHEncoding")
elif implementation == "tcnn":
encoding_config = {
"otype": "SphericalHarmonics",
"degree": levels,
}
encoding_config = self.get_tcnn_encoding_config(levels=self.levels)
self.tcnn_encoding = tcnn.Encoding(
n_input_dims=3,
encoding_config=encoding_config,
)

@classmethod
def get_tcnn_encoding_config(cls, levels) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {
"otype": "SphericalHarmonics",
"degree": levels,
}
return encoding_config

def get_out_dim(self) -> int:
return self.levels**2

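As with the other encodings, the spherical harmonics config is exposed as a classmethod; a short illustrative check, not part of the diff:

```python
from nerfstudio.field_components.encodings import SHEncoding

# Hypothetical check: levels maps directly to the tcnn "degree", and the
# encoding's output dimension is levels ** 2 (16 values for levels=4).
config = SHEncoding.get_tcnn_encoding_config(levels=4)
assert config == {"otype": "SphericalHarmonics", "degree": 4}
```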
176 changes: 149 additions & 27 deletions nerfstudio/field_components/mlp.py
@@ -17,12 +17,14 @@
"""
from typing import Literal, Optional, Set, Tuple, Union

import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor, nn

from nerfstudio.field_components.base_field_component import FieldComponent
from nerfstudio.utils.printing import print_tcnn_speed_warning
from nerfstudio.field_components.encodings import HashEncoding

from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
@@ -66,6 +68,7 @@ class MLP(FieldComponent):
out_dim: Output layer dimension. Uses layer_width if None.
activation: intermediate layer activation function.
out_activation: output activation function.
implementation: Implementation of the MLP. Fallback to torch if tcnn not available.
"""

def __init__(
@@ -98,39 +101,47 @@ def __init__(
print_tcnn_speed_warning("MLP")
self.build_nn_modules()
elif implementation == "tcnn":
activation_str = activation_to_tcnn_string(activation)
output_activation_str = activation_to_tcnn_string(out_activation)
if layer_width in [16, 32, 64, 128]:
network_config = {
"otype": "FullyFusedMLP",
"activation": activation_str,
"output_activation": output_activation_str,
"n_neurons": layer_width,
"n_hidden_layers": num_layers - 1,
}
else:
CONSOLE.line()
CONSOLE.print("[bold yellow]WARNING: Using slower TCNN CutlassMLP instead of TCNN FullyFusedMLP")
CONSOLE.print(
"[bold yellow]Use layer width of 16, 32, 64, or 128 to use the faster TCNN FullyFusedMLP."
)
CONSOLE.line()
network_config = {
"otype": "CutlassMLP",
"activation": activation_str,
"output_activation": output_activation_str,
"n_neurons": layer_width,
"n_hidden_layers": num_layers - 1,
}

network_config = self.get_tcnn_network_config(
activation=self.activation,
out_activation=self.out_activation,
layer_width=self.layer_width,
num_layers=self.num_layers,
)
self.tcnn_encoding = tcnn.Network(
n_input_dims=in_dim,
n_output_dims=out_dim,
n_output_dims=self.out_dim,
network_config=network_config,
)

@classmethod
def get_tcnn_network_config(cls, activation, out_activation, layer_width, num_layers) -> dict:
"""Get the network configuration for tcnn if implemented"""
activation_str = activation_to_tcnn_string(activation)
output_activation_str = activation_to_tcnn_string(out_activation)
if layer_width in [16, 32, 64, 128]:
network_config = {
"otype": "FullyFusedMLP",
"activation": activation_str,
"output_activation": output_activation_str,
"n_neurons": layer_width,
"n_hidden_layers": num_layers - 1,
}
else:
CONSOLE.line()
CONSOLE.print("[bold yellow]WARNING: Using slower TCNN CutlassMLP instead of TCNN FullyFusedMLP")
CONSOLE.print("[bold yellow]Use layer width of 16, 32, 64, or 128 to use the faster TCNN FullyFusedMLP.")
CONSOLE.line()
network_config = {
"otype": "CutlassMLP",
"activation": activation_str,
"output_activation": output_activation_str,
"n_neurons": layer_width,
"n_hidden_layers": num_layers - 1,
}
return network_config

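The classmethod preserves the existing width-based selection between the fused and Cutlass MLPs; a small illustrative check, not part of the diff:

```python
from torch import nn

from nerfstudio.field_components.mlp import MLP

# Hypothetical check: widths in {16, 32, 64, 128} select FullyFusedMLP; any other
# width falls back to the slower CutlassMLP (and prints a console warning).
fast = MLP.get_tcnn_network_config(activation=nn.ReLU(), out_activation=None, layer_width=64, num_layers=3)
slow = MLP.get_tcnn_network_config(activation=nn.ReLU(), out_activation=None, layer_width=256, num_layers=3)
assert fast["otype"] == "FullyFusedMLP" and slow["otype"] == "CutlassMLP"
assert fast["n_hidden_layers"] == 2  # num_layers - 1
```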
def build_nn_modules(self) -> None:
"""Initialize multi-layer perceptron."""
"""Initialize the torch version of the multi-layer perceptron."""
layers = []
if self.num_layers == 1:
layers.append(nn.Linear(self.in_dim, self.out_dim))
@@ -171,3 +182,114 @@ def forward(self, in_tensor: Float[Tensor, "*bs in_dim"]) -> Float[Tensor, "*bs
if self.tcnn_encoding is not None:
return self.tcnn_encoding(in_tensor)
return self.pytorch_fwd(in_tensor)


class MLPWithHashEncoding(FieldComponent):
"""Multilayer perceptron with hash encoding
Args:
num_levels: Number of feature grids.
min_res: Resolution of smallest feature grid.
max_res: Resolution of largest feature grid.
log2_hashmap_size: Size of hash map is 2^log2_hashmap_size.
features_per_level: Number of features per level.
hash_init_scale: Value to initialize hash grid.
interpolation: Interpolation override for tcnn hashgrid. Not supported for torch unless linear.
num_layers: Number of network layers
layer_width: Width of each MLP layer
out_dim: Output layer dimension. Uses layer_width if None.
activation: intermediate layer activation function.
out_activation: output activation function.
implementation: Implementation of the hash encoding and MLP. Fallback to torch if tcnn not available.
"""

def __init__(
self,
num_levels: int = 16,
min_res: int = 16,
max_res: int = 1024,
log2_hashmap_size: int = 19,
features_per_level: int = 2,
hash_init_scale: float = 0.001,
interpolation: Optional[Literal["Nearest", "Linear", "Smoothstep"]] = None,
num_layers: int = 2,
layer_width: int = 64,
out_dim: Optional[int] = None,
skip_connections: Optional[Tuple[int]] = None,
activation: Optional[nn.Module] = nn.ReLU(),
out_activation: Optional[nn.Module] = None,
implementation: Literal["tcnn", "torch"] = "torch",
) -> None:
super().__init__()
self.in_dim = 3

self.num_levels = num_levels
self.min_res = min_res
self.max_res = max_res
self.features_per_level = features_per_level
self.hash_init_scale = hash_init_scale
self.log2_hashmap_size = log2_hashmap_size
self.hash_table_size = 2**log2_hashmap_size

self.growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1)) if num_levels > 1 else 1

self.out_dim = out_dim if out_dim is not None else layer_width
self.num_layers = num_layers
self.layer_width = layer_width
self.skip_connections = skip_connections
self._skip_connections: Set[int] = set(skip_connections) if skip_connections else set()
self.activation = activation
self.out_activation = out_activation
self.net = None

self.tcnn_encoding = None
if implementation == "torch":
self.build_nn_modules()
elif implementation == "tcnn" and not TCNN_EXISTS:
print_tcnn_speed_warning("MLPWithHashEncoding")
self.build_nn_modules()
elif implementation == "tcnn":
self.model = tcnn.NetworkWithInputEncoding(
n_input_dims=self.in_dim,
n_output_dims=self.out_dim,
encoding_config=HashEncoding.get_tcnn_encoding_config(
num_levels=self.num_levels,
features_per_level=self.features_per_level,
log2_hashmap_size=self.log2_hashmap_size,
min_res=self.min_res,
growth_factor=self.growth_factor,
interpolation=interpolation,
),
network_config=MLP.get_tcnn_network_config(
activation=self.activation,
out_activation=self.out_activation,
layer_width=self.layer_width,
num_layers=self.num_layers,
),
)

def build_nn_modules(self) -> None:
"""Initialize the torch version of the MLP with hash encoding."""
encoder = HashEncoding(
num_levels=self.num_levels,
min_res=self.min_res,
max_res=self.max_res,
log2_hashmap_size=self.log2_hashmap_size,
features_per_level=self.features_per_level,
hash_init_scale=self.hash_init_scale,
implementation="torch",
)
mlp = MLP(
in_dim=encoder.get_out_dim(),
num_layers=self.num_layers,
layer_width=self.layer_width,
out_dim=self.out_dim,
skip_connections=self.skip_connections,
activation=self.activation,
out_activation=self.out_activation,
implementation="torch",
)
self.model = torch.nn.Sequential(encoder, mlp)

def forward(self, in_tensor: Float[Tensor, "*bs in_dim"]) -> Float[Tensor, "*bs out_dim"]:
return self.model(in_tensor)
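
With tinycudann available, the encoding and the MLP run as one fused tcnn.NetworkWithInputEncoding; the torch path composes the two existing modules with nn.Sequential, so both backends produce the same output width. A hedged shape check, not part of the diff:

```python
import torch

from nerfstudio.field_components.mlp import MLPWithHashEncoding

# Hypothetical sketch: the torch fallback is nn.Sequential(HashEncoding(...), MLP(...)),
# so the output width equals out_dim (which defaults to layer_width when out_dim is None).
model = MLPWithHashEncoding(num_layers=2, layer_width=64, out_dim=16, implementation="torch")
x = torch.rand(4096, 3)  # positions normalized to [0, 1]
assert model(x).shape == (4096, 16)
```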