Commit

correct initialization for first layer
lucidrains committed Jul 21, 2020
1 parent 7cdfcae commit 5160b9d
Showing 3 changed files with 23 additions and 20 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -24,7 +24,8 @@ net = SirenNet(
     dim_hidden = 256,                  # hidden dimension
     dim_out = 3,                       # output dimension, ex. rgb value
     num_layers = 5,                    # number of layers
-    final_activation = nn.Sigmoid()    # activation of final layer
+    final_activation = nn.Sigmoid(),   # activation of final layer (nn.Identity() for direct output)
+    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
 )
 
 coor = torch.randn(1, 2)
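
For reference, the configuration above slots into an end-to-end call as follows. This is a minimal usage sketch, not part of the diff: dim_in = 2 is inferred from the (1, 2) coordinate in the README, and the import assumes the package's top-level SirenNet export.

import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(
    dim_in = 2,                        # 2d input coordinate (inferred)
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    final_activation = nn.Sigmoid(),   # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # omega_0 of the first layer
)

coor = torch.randn(1, 2)   # a single 2d coordinate
rgb = net(coor)            # expected shape (1, 3), squashed to [0, 1] by the sigmoid
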
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'siren-pytorch',
   packages = find_packages(),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'Implicit Neural Representations with Periodic Activation Functions',
   author = 'Phil Wang',
siren_pytorch/siren_pytorch.py (38 changes: 20 additions & 18 deletions)
@@ -3,18 +3,6 @@
 from torch import nn
 import torch.nn.functional as F
 
-# siren initialization
-
-def init_(weight, bias, c = 6., w0 = 1.):
-    dim = weight.size(1)
-    std = 1 / math.sqrt(dim)
-
-    w_std = math.sqrt(c) * std / w0
-    weight.uniform_(-w_std, w_std)
-
-    if bias is not None:
-        bias.uniform_(-std, std)
-
 # sin activation
 
 class Sine(nn.Module):
@@ -27,16 +15,28 @@ def forward(self, x):
 # siren layer
 
 class Siren(nn.Module):
-    def __init__(self, dim_in, dim_out, w0 = 1., c = 6., use_bias = True, activation = None):
+    def __init__(self, dim_in, dim_out, w0 = 30., c = 6., is_first = False, use_bias = True, activation = None):
         super().__init__()
+        self.dim_in = dim_in
+        self.is_first = is_first
+
         weight = torch.zeros(dim_out, dim_in)
         bias = torch.zeros(dim_out) if use_bias else None
-        init_(weight, bias, c = c, w0 = w0)
+        self.init_(weight, bias, c = c, w0 = w0)
 
         self.weight = nn.Parameter(weight)
         self.bias = nn.Parameter(bias) if use_bias else None
         self.activation = Sine(w0) if activation is None else activation
 
+    def init_(self, weight, bias, c, w0):
+        dim = self.dim_in
+
+        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
+        weight.uniform_(-w_std, w_std)
+
+        if bias is not None:
+            bias.uniform_(-w_std, w_std)
+
     def forward(self, x):
         out = F.linear(x, self.weight, self.bias)
         out = self.activation(out)
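
The hunk above is the substance of the commit: initialization moves into the layer, and the first layer now draws its weights from U(-1/dim_in, 1/dim_in), while every later layer keeps the U(-sqrt(c/dim_in)/w0, sqrt(c/dim_in)/w0) bound that the removed module-level init_ previously applied to all layers; biases now share the weight bound as well. The following sketch merely restates those bounds with illustrative dimensions; siren_weight_bound is a hypothetical helper, not part of the package.

import math

def siren_weight_bound(dim_in, w0 = 30., c = 6., is_first = False):
    # same bound Siren.init_ draws from above:
    #   first layer:      U(-1 / dim_in, 1 / dim_in)
    #   remaining layers: U(-sqrt(c / dim_in) / w0, sqrt(c / dim_in) / w0)
    return (1 / dim_in) if is_first else (math.sqrt(c / dim_in) / w0)

print(siren_weight_bound(2, is_first = True))   # 0.5 for a 2d coordinate input
print(siren_weight_bound(256))                  # ~0.0051 for a 256-dim hidden layer
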
@@ -45,18 +45,20 @@ def forward(self, x):
 # siren network
 
 class SirenNet(nn.Module):
-    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 1., w0_initial = 30., use_bias = True, final_activation = None):
+    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 30., w0_initial = 30., use_bias = True, final_activation = None):
         super().__init__()
         layers = []
         for ind in range(num_layers):
-            layer_w0 = w0_initial if ind == 0 else w0
-            layer_dim_in = dim_in if ind == 0 else dim_hidden
+            is_first = ind == 0
+            layer_w0 = w0_initial if is_first else w0
+            layer_dim_in = dim_in if is_first else dim_hidden
 
             layers.append(Siren(
                 dim_in = layer_dim_in,
                 dim_out = dim_hidden,
                 w0 = layer_w0,
-                use_bias = use_bias
+                use_bias = use_bias,
+                is_first = is_first
             ))
 
         self.net = nn.Sequential(*layers)
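
Taken together, the two hunks mean only the first SirenNet layer is flagged is_first and receives w0_initial. A quick sanity-check sketch, illustrative only and not part of the repository's tests:

import torch
from siren_pytorch import SirenNet

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5)

first, hidden = net.net[0], net.net[1]
print(first.is_first, hidden.is_first)                       # True False
print(first.weight.abs().max() <= 1 / 2)                     # first layer bound: 1 / dim_in
print(hidden.weight.abs().max() <= (6 / 256) ** 0.5 / 30.)   # hidden layer bound: sqrt(c / dim_hidden) / w0
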
