Skip to content

Commit

Permalink
Add new file
Browse files Browse the repository at this point in the history
  • Loading branch information
Jona te Lintelo committed Mar 22, 2023
1 parent 5aaec71 commit d8d1d43
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions skeleton/layers/resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch as t
import torch.nn as nn

from skeleton.layers.ResidualBlock import ResidualBlock

class ResNet(nn.Module):
def __init__(self, triples):
super(ResNet, self).__init__()
modules = []
for i, triple in enumerate(triples):
block = self.block(*triple)
modules.append(block)

self.sb = self.starting_block(40, 32)

self.net = nn.Sequential(*modules)

def starting_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=2, padding=3),
nn.BatchNorm1d(out_channels),
nn.ReLU(),
nn.MaxPool1d(kernel_size=3, stride=2, padding=1))

def block(self, in_channels, num_residuals, out_channels):
blk = []
for i in range(num_residuals):
if i == 0:
blk.append(ResidualBlock(in_channels, out_channels))
else:
blk.append(ResidualBlock(in_channels*2, out_channels))
return nn.Sequential(*blk)

def forward(self, x):
x = self.sb(x)
x = self.net(x)
return x

0 comments on commit d8d1d43

Please sign in to comment.