From a60784b0c54e15c011e7105ba06f160abf51fd86 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 9 Apr 2021 12:35:44 -0700
Subject: [PATCH] feed z into every layer of mod network

---
 README.md                      |  6 +++---
 setup.py                       |  2 +-
 siren_pytorch/siren_pytorch.py | 22 +++++++++-------------
 3 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/README.md b/README.md
index 5999be0..4b16f46 100644
--- a/README.md
+++ b/README.md
@@ -94,7 +94,7 @@ pred_img = wrapper() # (1, 3, 256, 256)
 
 A new paper proposes that the best way to condition a Siren with a latent code is to pass the latent vector through a modulator feedforward network, where each layer's hidden state is elementwise multiplied with the corresponding layer of the Siren.
 
-You can use this simply by setting an extra keyword `latent_dim`, on the `SirenWrapper`
+You can use this simply by setting an extra keyword `use_latent`, on the `SirenWrapper`
 
 ```python
 import torch
@@ -111,12 +111,12 @@ net = SirenNet(
 
 wrapper = SirenWrapper(
     net,
-    latent_dim = 512,
+    use_latent = True,
     image_width = 256,
     image_height = 256
 )
 
-latent = torch.randn(512)
+latent = nn.Parameter(torch.zeros(256).normal_(0, 1e-2))
 img = torch.randn(1, 3, 256, 256)
 
 loss = wrapper(img, latent = latent)
diff --git a/setup.py b/setup.py
index fca7192..d37708f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'siren-pytorch',
   packages = find_packages(),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'Implicit Neural Representations with Periodic Activation Functions',
   author = 'Phil Wang',
diff --git a/siren_pytorch/siren_pytorch.py b/siren_pytorch/siren_pytorch.py
index e461c0a..c69bdbc 100644
--- a/siren_pytorch/siren_pytorch.py
+++ b/siren_pytorch/siren_pytorch.py
@@ -83,32 +83,29 @@ def forward(self, x, mods = None):
             x = layer(x)
 
             if exists(mod):
-                x *= rearrange(mod, 'd -> () d')
+                x *= rearrange(mod, 'd -> () d').sigmoid()
 
         return self.last_layer(x)
 
 # modulatory feed forward
 
 class Modulator(nn.Module):
-    def __init__(self, dim_in, dim_hidden, num_layers):
+    def __init__(self, dim, num_layers):
         super().__init__()
-        dim = dim_in
         self.layers = nn.ModuleList([])
 
         for ind in range(num_layers):
-            is_first = ind == 0
-            dim = dim_in if is_first else dim_hidden
-
             self.layers.append(nn.Sequential(
-                nn.Linear(dim, dim_hidden),
+                nn.Linear(dim, dim),
                 nn.ReLU()
             ))
 
-    def forward(self, x):
+    def forward(self, z):
+        x = z
         hiddens = []
 
         for layer in self.layers:
-            x = layer(x)
+            x = layer(x + z)
             hiddens.append(x)
 
         return tuple(hiddens)
@@ -116,7 +113,7 @@ def forward(self, x):
 # wrapper
 
 class SirenWrapper(nn.Module):
-    def __init__(self, net, image_width, image_height, latent_dim = None):
+    def __init__(self, net, image_width, image_height, use_latent = False):
         super().__init__()
         assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'
 
@@ -125,10 +122,9 @@ def __init__(self, net, image_width, image_height, latent_dim = None):
         self.image_height = image_height
 
         self.modulator = None
-        if exists(latent_dim):
+        if use_latent:
             self.modulator = Modulator(
-                dim_in = latent_dim,
-                dim_hidden = net.dim_hidden,
+                dim = net.dim_hidden,
                 num_layers = net.num_layers
             )
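
For reference, here is a minimal, self-contained sketch of the README usage as it reads after this patch. This is a reconstruction under stated assumptions, not part of the diff: the `SirenNet` hyperparameters (`dim_in = 2`, `dim_hidden = 256`, etc.) are taken from the rest of the README rather than from the hunks above, and `from torch import nn` is added because the new snippet uses `nn.Parameter` while only `import torch` appears in the shown diff context.

```python
# Sketch of latent conditioning after this patch (assumed hyperparameters).
import torch
from torch import nn                  # needed for nn.Parameter; not shown in the diff context
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                       # input dimension, e.g. 2d coordinates
    dim_hidden = 256,                 # hidden dimension
    dim_out = 3,                      # output dimension, e.g. rgb value
    num_layers = 5,                   # number of Siren layers
    w0_initial = 30.                  # omega_0 for the first layer
)

wrapper = SirenWrapper(
    net,
    use_latent = True,                # modulator is now sized from net.dim_hidden internally
    image_width = 256,
    image_height = 256
)

# the latent must match net.dim_hidden (256 here), since every modulator layer now maps dim -> dim
latent = nn.Parameter(torch.zeros(256).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)

loss = wrapper(img, latent = latent)
loss.backward()
```

The design change the patch makes: instead of the modulator projecting an arbitrary `latent_dim` into `dim_hidden` at its first layer only, the latent `z` is kept at `dim_hidden` and re-added to the hidden state before every layer (`x = layer(x + z)`), so each Siren layer's modulation sees the raw latent as well as the accumulated hidden state.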