
Commit

feed z into every layer of mod network
lucidrains committed Apr 9, 2021
1 parent 77b6b14 commit a60784b
Showing 3 changed files with 13 additions and 17 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -94,7 +94,7 @@ pred_img = wrapper() # (1, 3, 256, 256)

A <a href="https://arxiv.org/abs/2104.03960">new paper</a> proposes that the best way to condition a Siren with a latent code is to pass the latent vector through a modulator feedforward network, where each layer's hidden state is elementwise multiplied with the corresponding layer of the Siren.

-You can use this simply by setting an extra keyword `latent_dim`, on the `SirenWrapper`
+You can use this simply by setting an extra keyword `use_latent`, on the `SirenWrapper`

```python
import torch
@@ -111,12 +111,12 @@ net = SirenNet(

wrapper = SirenWrapper(
    net,
-    latent_dim = 512,
+    use_latent = True,
    image_width = 256,
    image_height = 256
)

-latent = torch.randn(512)
+latent = nn.Parameter(torch.zeros(256).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)

loss = wrapper(img, latent = latent)
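For reference, here is a minimal end-to-end sketch of the updated usage, assuming the `use_latent` keyword and latent handling introduced in this diff; the `SirenNet` hyperparameters and the final generation call are illustrative, mirroring the rest of the README:

```python
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                       # (x, y) pixel coordinates
    dim_hidden = 256,                 # hidden width, shared with the modulator
    dim_out = 3,                      # RGB
    num_layers = 5
)

wrapper = SirenWrapper(
    net,
    use_latent = True,                # builds an internal Modulator sized to net.dim_hidden
    image_width = 256,
    image_height = 256
)

# learnable latent code; its size must match dim_hidden, since the
# modulator now maps dim_hidden -> dim_hidden at every layer
latent = nn.Parameter(torch.zeros(256).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)

loss = wrapper(img, latent = latent)  # reconstruction loss against img
loss.backward()

pred_img = wrapper(latent = latent)   # (1, 3, 256, 256), generated from the latent
```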
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'siren-pytorch',
  packages = find_packages(),
-  version = '0.1.1',
+  version = '0.1.2',
  license='MIT',
  description = 'Implicit Neural Representations with Periodic Activation Functions',
  author = 'Phil Wang',
22 changes: 9 additions & 13 deletions siren_pytorch/siren_pytorch.py
@@ -83,40 +83,37 @@ def forward(self, x, mods = None):
            x = layer(x)

            if exists(mod):
-                x *= rearrange(mod, 'd -> () d')
+                x *= rearrange(mod, 'd -> () d').sigmoid()

        return self.last_layer(x)

# modulatory feed forward

class Modulator(nn.Module):
-    def __init__(self, dim_in, dim_hidden, num_layers):
+    def __init__(self, dim, num_layers):
        super().__init__()
-        dim = dim_in
        self.layers = nn.ModuleList([])

        for ind in range(num_layers):
-            is_first = ind == 0
-            dim = dim_in if is_first else dim_hidden
-
            self.layers.append(nn.Sequential(
-                nn.Linear(dim, dim_hidden),
+                nn.Linear(dim, dim),
                nn.ReLU()
            ))

-    def forward(self, x):
+    def forward(self, z):
+        x = z
        hiddens = []

        for layer in self.layers:
-            x = layer(x)
+            x = layer(x + z)
            hiddens.append(x)

        return tuple(hiddens)

# wrapper

class SirenWrapper(nn.Module):
-    def __init__(self, net, image_width, image_height, latent_dim = None):
+    def __init__(self, net, image_width, image_height, use_latent = False):
        super().__init__()
        assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'

@@ -125,10 +122,9 @@ def __init__(self, net, image_width, image_height, latent_dim = None):
        self.image_height = image_height

        self.modulator = None
-        if exists(latent_dim):
+        if use_latent:
            self.modulator = Modulator(
-                dim_in = latent_dim,
-                dim_hidden = net.dim_hidden,
+                dim = net.dim_hidden,
                num_layers = net.num_layers
            )

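To make the change concrete, below is a self-contained sketch of the modulation pattern after this commit: the `Modulator` mirrors the diff above, while the driver code at the bottom is purely illustrative (not part of the library) and stands in for how `SirenNet.forward` gates its hidden states with `sigmoid`:

```python
import torch
from torch import nn

# Modulator after this commit: the latent z is re-injected at every layer,
# and each hidden state later gates the corresponding Siren layer.
class Modulator(nn.Module):
    def __init__(self, dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
            for _ in range(num_layers)
        ])

    def forward(self, z):
        x = z
        hiddens = []

        for layer in self.layers:
            x = layer(x + z)          # feed z into every layer, not just the first
            hiddens.append(x)

        return tuple(hiddens)

# illustrative driver: apply the per-layer modulations the way
# SirenNet.forward does after this commit, i.e. elementwise multiply
# by the sigmoid of each modulation vector
dim_hidden, num_layers = 256, 5
modulator = Modulator(dim_hidden, num_layers)

z = torch.zeros(dim_hidden).normal_(0, 1e-2)    # latent code sized to dim_hidden
mods = modulator(z)                             # num_layers tensors of shape (dim_hidden,)

h = torch.randn(4096, dim_hidden)               # stand-in for a batch of Siren hidden states
for mod in mods:
    h = h * mod.sigmoid()                       # gating applied after each Siren layer

print(len(mods), mods[0].shape, h.shape)        # 5, torch.Size([256]), torch.Size([4096, 256])
```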
