Skip to content

Commit

Permalink
Update prototype.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Jona te Lintelo committed Mar 22, 2023
1 parent dc4ec57 commit e5dc476
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions skeleton/models/prototype.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
EvaluationPair,
evaluate_speaker_trials,
)
from skeleton.layers.resnet import ResNet
from skeleton.layers.resnext import ResNext

from skeleton.layers.statistical_pooling import MeanStatPool1D

Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
nn.ReLU(),
)

self.resnet = ResNet(((num_embedding, 2, num_embedding*2),(num_embedding*2, 2, num_embedding*4), (num_embedding*4, 2, num_embedding*8), (num_embedding*8, 2, num_embedding*16)))
self.resnet = ResNext(((num_embedding, 2, num_embedding*2),(num_embedding*2, 2, num_embedding*4), (num_embedding*4, 2, num_embedding*8), (num_embedding*8, 2, num_embedding*16)))

# Pooling layer
# assuming input of shape [BATCH_SIZE, NUM_EMBEDDING, REDUCED_NUM_FRAMES]
Expand Down Expand Up @@ -111,12 +111,12 @@ def forward(self, spectrogram: t.Tensor) -> Tuple[t.Tensor, t.Tensor]:
def compute_embedding(self, spectrogram: t.Tensor) -> t.Tensor:
# modify to your liking!
feature_representation = self.embedding_layer(spectrogram) # -> [128,128,239]
resnet_output = self.resnet(feature_representation)
output = self.resnext(feature_representation)


resnet_output = resnet_output[:, :, None] # -> ([128, 128, 1])
output = output[:, :, None] # -> ([128, 128, 1])

embedding = self.pooling_layer(resnet_output) # -> [128, 128]
embedding = self.pooling_layer(output) # -> [128, 128]
return embedding

def compute_prediction(self, embedding: t.Tensor) -> t.Tensor:
Expand Down

0 comments on commit e5dc476

Please sign in to comment.