Skip to content

Commit

Permalink
add SirenWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 18, 2021
1 parent 3c602e7 commit da39fc6
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 3 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,30 @@ coor = torch.randn(1, 2)
act(coor)
```

Wrapper to return an image of specified height and width from a given `SirenNet`, for training and inference

```python
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
dim_in = 2, # input dimension, ex. 2d coor
dim_hidden = 256, # hidden dimension
dim_out = 3, # output dimension, ex. rgb value
num_layers = 5, # number of layers
w0_initial = 30. # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
net,
image_width = 256,
image_height = 256
)

image = wrapper() # (1, 3, 256, 256)
```

## Citations

```bibtex
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
setup(
name = 'siren-pytorch',
packages = find_packages(),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Implicit Neural Representations with Periodic Activation Functions',
author = 'Phil Wang',
author_email = '[email protected]',
url = 'https://github.com/lucidrains/siren-pytorch',
keywords = ['artificial intelligence', 'deep learning'],
install_requires=[
'einops',
'torch'
],
classifiers=[
Expand Down
2 changes: 1 addition & 1 deletion siren_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from siren_pytorch.siren_pytorch import Siren, SirenNet, Sine
from siren_pytorch.siren_pytorch import Sine, Siren, SirenNet, SirenWrapper
25 changes: 24 additions & 1 deletion siren_pytorch/siren_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

# helpers

Expand Down Expand Up @@ -74,3 +75,25 @@ def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 30., w0_initial
def forward(self, x):
x = self.net(x)
return self.last_layer(x)

# wrapper

class SirenWrapper(nn.Module):
def __init__(self, net, image_width, image_height):
super().__init__()
assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'

self.net = net
self.image_width = image_width
self.image_height = image_height

tensors = [torch.linspace(-1, 1, steps = image_width), torch.linspace(-1, 1, steps = image_height)]
mgrid = torch.stack(torch.meshgrid(*tensors), dim=-1)
mgrid = rearrange(mgrid, 'h w c -> (h w) c')
self.register_buffer('grid', mgrid)

def forward(self):
coords = self.grid.clone().detach().requires_grad_()
out = self.net(coords)
out = rearrange(out, '(h w) c -> () c h w', h = self.image_height, w = self.image_width)
return out

0 comments on commit da39fc6

Please sign in to comment.