Skip to content

Commit

Permalink
Update prototype.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Jona te Lintelo committed Mar 22, 2023
1 parent e5dc476 commit 0a6cde0
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions skeleton/models/prototype.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def __init__(
nn.ReLU(),
)

# Change prediction_layer in_featuers to the output size of the resnet model used.
# Change the resnet called in the compute_embedding to the resnet model used.
# Change the compute_embedding function entirely if ResNet is used instead of ResNeXt.
self.resnet10 = ResNet(((32,2,64),(64,2,128)))
self.resnet18 = ResNet(((32,2,64),(64,2,128),(128,2,256),(256,2,512)))
self.resnet34 = ResNet(((32,3,64),(64,4,128),(128,6,256),(256,3,512)))
self.resnet = ResNext(((num_embedding, 2, num_embedding*2),(num_embedding*2, 2, num_embedding*4), (num_embedding*4, 2, num_embedding*8), (num_embedding*8, 2, num_embedding*16)))

# Pooling layer
Expand Down Expand Up @@ -119,6 +125,12 @@ def compute_embedding(self, spectrogram: t.Tensor) -> t.Tensor:
embedding = self.pooling_layer(output) # -> [128, 128]
return embedding

# Use this function if ResNet is used instead of ResNeXt.
# def compute_embedding(self, spectrogram: t.Tensor) -> t.Tensor:
# resnet_output = self.resnet10(spectrogram)
# embedding = self.pooling_layer(resnet_output)
# return embedding

def compute_prediction(self, embedding: t.Tensor) -> t.Tensor:
# modify to your liking!
# embedding = embedding[None, :, :]
Expand Down

0 comments on commit 0a6cde0

Please sign in to comment.