Skip to content

Commit

Permalink
initial release
Browse files Browse the repository at this point in the history
  • Loading branch information
adambielski committed Mar 6, 2018
1 parent 3214a0b commit 4ed17ef
Show file tree
Hide file tree
Showing 30 changed files with 4,634 additions and 2 deletions.
1,888 changes: 1,888 additions & 0 deletions Experiments_FashionMNIST.ipynb

Large diffs are not rendered by default.

1,803 changes: 1,803 additions & 0 deletions Experiments_MNIST.ipynb

Large diffs are not rendered by default.

194 changes: 192 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,192 @@
# siamese-triplet
Siamese and triplet networks with online pair/triplet mining in PyTorch
# Siamese and triplet learning with online pair/triplet mining

PyTorch implementation of siamese and triplet networks for learning embeddings.

Siamese and triplet networks are useful to learn mappings from image to a compact Euclidean space where distances correspond to a measure of similarity [2]. Embeddings trained in such way can be used as features vectors for classification or few-shot learning tasks.

# Installation

Requires [pytorch](http://pytorch.org/) 0.3.1 with torchvision 0.2.0

# Code structure

- **datasets.py**
- *SiameseMNIST* class - wrapper for a MNIST-like dataset, returning random positive and negative pairs
- *TripletMNIST* class - wrapper for a MNIST-like dataset, returning random triplets (anchor, positive and negative)
- *BalancedBatchSampler* class - BatchSampler for data loader, randomly chooses *n_classes* and *n_samples* from each class of a MNIST-like dataset
- **networks.py**
- *EmbeddingNet* - base network for encoding images into embedding vector
- *ClassificationNet* - wrapper for an embedding network, adds a fully connected layer and log softmax for classification
- *SiameseNet* - wrapper for an embedding network, processes pairs of inputs
- *TripletNet* - wrapper for an embedding network, processes triplets of inputs
- **losses.py**
- *ContrastiveLoss* - contrastive loss for pairs of embeddings and pair target (same/different)
- *TripletLoss* - triplet loss for triplets of embeddings
- *OnlineContrastiveLoss* - contrastive loss for a mini-batch of embeddings. Uses a *PairSelector* object to find positive and negative pairs within a mini-batch using ground truth class labels and computes contrastive loss for these pairs
- *OnlineTripletLoss* - triplet loss for a mini-batch of embeddings. Uses a *TripletSelector* object to find triplets within a mini-batch using ground truth class labels and computes triplet loss
- **trainer.py**
- *fit* - unified function for training a network with different number of inputs and different types of loss functions
- **metrics.py**
- Sample metrics that can be used with *fit* function from *trainer.py*
- **utils.py**
- *PairSelector* - abstract class defining objects generating pairs based on embeddings and ground truth class labels. Can be used with *OnlineContrastiveLoss*.
- *AllPositivePairSelector, HardNegativePairSelector* - PairSelector implementations
- *TripletSelector* - abstract class defining objects generating triplets based on embeddings and ground truth class labels. Can be used with *OnlineTripletLoss*.
- *AllTripletSelector*, *HardestNegativeTripletSelector*, *RandomNegativeTripletSelector*, *SemihardNegativeTripletSelector* - TripletSelector implementations

# Examples

We'll train embeddings on MNIST dataset. Experiments were run in [jupyter notebook](Experiments_MNIST.ipynb).

We'll go through learning supervised feature embeddings using different loss functions on MNIST dataset. This is just for visualization purposes, thus we'll be using 2-dimensional embeddings which isn't the best choice in practice.

For every experiment the same embedding network is used (32 conv 5x5 -> PReLU -> MaxPool 2x2 -> 64 conv 5x5 -> PReLU -> MaxPool 2x2 -> Dense 256 -> PReLU -> Dense 256 -> PReLU -> Dense 2) and we don't perform any hyperparameter search.

## Baseline - classification with softmax

We add a fully-connected layer with the number of classes and train the network for classification with softmax and cross-entropy. The network trains to ~99% accuracy. We extract 2 dimensional embeddings from penultimate layer:

Train set:

![](images/mnist_softmax_train.png)

Test set:

![](images/mnist_softmax_test.png)

While the embeddings look separable (which is what we trained them for), they don't have good metric properties. They might not be the best choice as a descriptor for new classes.

## Siamese network

Now we'll train a siamese network that takes a pair of images and trains the embeddings so that the distance between them is minimized if they're from the same class and is greater than some margin value if they represent different classes.
We'll minimize a contrastive loss function [1]:
$$L_{contrastive}(x_0, x_1, y) = \frac{1}{2} y \lVert f(x_0)-f(x_1)\rVert_2^2 + \frac{1}{2}(1-y)\{max(0, m-\lVert f(x_0)-f(x_1)\rVert_2\}^2$$

*SiameseMNIST* class samples random positive and negative pairs that are then fed to Siamese Network.

After 20 epochs of training here are the embeddings we get for training set:

![](images/mnist_siamese_train.png)

Test set:

![](images/mnist_siamese_test.png)

The learned embeddings are clustered much better within class.

## Triplet network

We'll train a triplet network, that takes an anchor, a positive (of same class as an anchor) and negative (of different class than an anchor) examples. The objective is to learn embeddings such that the anchor is closer to the positive example than it is to the negative example by some margin value.

![alt text](images/anchor_negative_positive.png "Source: FaceNet")
Source: *Schroff, Florian, Dmitry Kalenichenko, and James Philbin. [Facenet: A unified embedding for face recognition and clustering.](https://arxiv.org/abs/1503.03832) CVPR 2015.*

**Triplet loss**: $L_{triplet}(x_a, x_p, x_n) = m + \lVert f(x_a)-f(x_p)\rVert_2^2 - \lVert f(x_a)-f(x_n)\rVert_2^2$

*TripletMNIST* class samples a positive and negative example for every possible anchor.

After 20 epochs of training here are the embeddings we get for training set:

![](images/mnist_triplet_train.png)

Test set:

![](images/mnist_triplet_test.png)

The learned embeddings are not as close to each other within class as in case of siamese network, but that's not what we optimized them for. We wanted the embeddings to be closer to other embeddings from the same class than from the other classes and we can see that's where the training is going to.

## Online pair/triplet selection - negative mining

There are couple of problems with siamese and triplet networks:
1. The **number of possible pairs/triplets** grows **quadratically/cubically** with the number of examples. It's infeasible to process them all and the training converges slowly.
2. We generate pairs/triplets *randomly*. As the training continues, more and more pairs/triplets are **easy** to deal with (their loss value is very small or even 0), *preventing the network from training*. We need to provide the network with **hard examples**.
3. Each image that is fed to the network is used only for computation of contrastive/triplet loss for only one pair/triplet. The computation is somewhat wasted; once the embedding is computed, it could be reused for many pairs/triplets.

To deal with these issues efficiently, we'll feed a network with standard mini-batches as we did for classification. The loss function will be responsible for selection of hard pairs and triplets within mini-batch. If we feed the network with 16 images per 10 classes, we can process up to $159*160/2 = 12720$ pairs and $10*16*15/2*(9*16) = 172800$ triplets, compared to 80 pairs and 53 triplets in previous implementation.

Usually it's not the best idea to process all possible pairs or triplets within a mini-batch. We can find some strategies on how to select triplets in [2] and [3].

### Online pair selection

We'll feed a network with mini-batches, as we did for classification network. This time we'll use a special BatchSampler that will sample *n_classes* and *n_samples* within each class, resulting in mini batches of size *n_classes\*n_samples*.

For each mini batch positive and negative pairs will be selected using provided labels.

MNIST is a rather easy dataset and the embeddings from the randomly selected pairs were quite good already, we don't see much improvement here.

**Train embeddings:**

![](images/mnist_ocl_train.png)

**Test embeddings:**

![](images/mnist_ocl_test.png)

### Online triplet selection

We'll feed a network with mini-batches just like with online pair selection. There are couple of strategies we can use for triplet selection given labels and predicted embeddings:

- All possible triplets (might be too many)
- Hardest negative for each positive pair (will result in the same negative for each anchor)
- Random hard negative for each positive pair (consider only triplets with positive triplet loss value)
- Semi-hard negative for each positive pair (similar to [2])

The strategy for triplet selection must be chosen carefully. A bad strategy might lead to inefficient training or, even worse, to model collapsing (all embeddings ending up having the same values).

Here's what we got with random hard negatives for each positive pair.

**Training set:**

![](images/mnist_otl_train.png)

**Test set:**

![](images/mnist_otl_test.png)

# FashionMNIST

Similar experiments were conducted for FashionMNIST dataset where advantages of online negative mining are more visible. The exact same network architecture with only 2-dimensional embeddings was used, which is probably not complex enough for learning good embeddings.

## Baseline - classification

![](images/fmnist_softmax_test.png)

## Siamese vs online contrastive loss with negative mining

Siamese network with randomly selected pairs

![](images/fmnist_siamese_test.png)

Online contrastive loss with negative mining

![](images/fmnist_ocl_test.png)

## Triplet vs online triplet loss with negative mining

Triplet network with random triplets

![](images/fmnist_triplet_test.png)

Online triplet loss with negative mining

![](images/fmnist_otl_test.png)

# TODO

- [ ] Optimize triplet selection
- [ ] Evaluate with a metric that is comparable between approaches
- [ ] Evaluate in one-shot setting when classes from test set are not in train set
- [ ] Show online triplet selection example on more difficult datasets

# References

[1] Raia Hadsell, Sumit Chopra, Yann LeCun, [Dimensionality reduction by learning an invariant mapping](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf), CVPR 2006

[2] Schroff, Florian, Dmitry Kalenichenko, and James Philbin. [Facenet: A unified embedding for face recognition and clustering.](https://arxiv.org/abs/1503.03832) CVPR 2015

[3] Alexander Hermans, Lucas Beyer, Bastian Leibe, [In Defense of the Triplet Loss for Person Re-Identification](https://arxiv.org/pdf/1703.07737), 2017

[4] Brandon Amos, Bartosz Ludwiczuk, Mahadev Satyanarayanan, [OpenFace: A general-purpose face recognition library with mobile applications](http://reports-archive.adm.cs.cmu.edu/anon/2016/CMU-CS-16-118.pdf), 2016

[5] Yi Sun, Xiaogang Wang, Xiaoou Tang, [Deep Learning Face Representation by Joint Identification-Verification](http://papers.nips.cc/paper/5416-deep-learning-face-representation-by-joint-identification-verification), NIPS 2014

188 changes: 188 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import numpy as np
from PIL import Image

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler


class SiameseMNIST(Dataset):
"""
Train: For each sample creates randomly a positive or a negative pair
Test: Creates fixed pairs for testing
"""

def __init__(self, mnist_dataset):
self.mnist_dataset = mnist_dataset

self.train = self.mnist_dataset.train
self.transform = self.mnist_dataset.transform

if self.train:
self.train_labels = self.mnist_dataset.train_labels
self.train_data = self.mnist_dataset.train_data
self.labels_set = set(self.train_labels.numpy())
self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
for label in self.labels_set}
else:
# generate fixed pairs for testing
self.test_labels = self.mnist_dataset.test_labels
self.test_data = self.mnist_dataset.test_data
self.labels_set = set(self.test_labels.numpy())
self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
for label in self.labels_set}

random_state = np.random.RandomState(29)

positive_pairs = [[i,
random_state.choice(self.label_to_indices[self.test_labels[i]]),
1]
for i in range(0, len(self.test_data), 2)]

negative_pairs = [[i,
random_state.choice(self.label_to_indices[
np.random.choice(
list(self.labels_set - set([self.test_labels[i]]))
)
]),
0]
for i in range(1, len(self.test_data), 2)]
self.test_pairs = positive_pairs + negative_pairs

def __getitem__(self, index):
if self.train:
target = np.random.randint(0, 2)
img1, label1 = self.train_data[index], self.train_labels[index]
if target == 1:
siamese_index = index
while siamese_index == index:
siamese_index = np.random.choice(self.label_to_indices[label1])
else:
siamese_label = np.random.choice(list(self.labels_set - set([label1])))
siamese_index = np.random.choice(self.label_to_indices[siamese_label])
img2 = self.train_data[siamese_index]
else:
img1 = self.test_data[self.test_pairs[index][0]]
img2 = self.test_data[self.test_pairs[index][1]]
target = self.test_pairs[index][2]

img1 = Image.fromarray(img1.numpy(), mode='L')
img2 = Image.fromarray(img2.numpy(), mode='L')
if self.transform is not None:
img1 = self.transform(img1)
img2 = self.transform(img2)
return (img1, img2), target

def __len__(self):
return len(self.mnist_dataset)


class TripletMNIST(Dataset):
"""
Train: For each sample (anchor) randomly chooses a positive and negative samples
Test: Creates fixed triplets for testing
"""

def __init__(self, mnist_dataset):
self.mnist_dataset = mnist_dataset
self.train = self.mnist_dataset.train
self.transform = self.mnist_dataset.transform

if self.train:
self.train_labels = self.mnist_dataset.train_labels
self.train_data = self.mnist_dataset.train_data
self.labels_set = set(self.train_labels.numpy())
self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
for label in self.labels_set}

else:
self.test_labels = self.mnist_dataset.test_labels
self.test_data = self.mnist_dataset.test_data
# generate fixed triplets for testing
self.labels_set = set(self.test_labels.numpy())
self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
for label in self.labels_set}

random_state = np.random.RandomState(29)

triplets = [[i,
random_state.choice(self.label_to_indices[self.test_labels[i]]),
random_state.choice(self.label_to_indices[
np.random.choice(
list(self.labels_set - set([self.test_labels[i]]))
)
])
]
for i in range(len(self.test_data))]
self.test_triplets = triplets

def __getitem__(self, index):
if self.train:
img1, label1 = self.train_data[index], self.train_labels[index]
positive_index = index
while positive_index == index:
positive_index = np.random.choice(self.label_to_indices[label1])
negative_label = np.random.choice(list(self.labels_set - set([label1])))
negative_index = np.random.choice(self.label_to_indices[negative_label])
img2 = self.train_data[positive_index]
img3 = self.train_data[negative_index]
else:
img1 = self.test_data[self.test_triplets[index][0]]
img2 = self.test_data[self.test_triplets[index][1]]
img3 = self.test_data[self.test_triplets[index][2]]

img1 = Image.fromarray(img1.numpy(), mode='L')
img2 = Image.fromarray(img2.numpy(), mode='L')
img3 = Image.fromarray(img3.numpy(), mode='L')
if self.transform is not None:
img1 = self.transform(img1)
img2 = self.transform(img2)
img3 = self.transform(img3)
return (img1, img2, img3), []

def __len__(self):
return len(self.mnist_dataset)




class BalancedBatchSampler(BatchSampler):
"""
BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
Returns batches of size n_classes * n_samples
"""

def __init__(self, dataset, n_classes, n_samples):
if dataset.train:
self.labels = dataset.train_labels
else:
self.labels = dataset.test_labels
self.labels_set = list(set(self.labels.numpy()))
self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
for label in self.labels_set}
for l in self.labels_set:
np.random.shuffle(self.label_to_indices[l])
self.used_label_indices_count = {label: 0 for label in self.labels_set}
self.count = 0
self.n_classes = n_classes
self.n_samples = n_samples
self.dataset = dataset
self.batch_size = self.n_samples * self.n_classes

def __iter__(self):
self.count = 0
while self.count + self.batch_size < len(self.dataset):
classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
indices = []
for class_ in classes:
indices.extend(self.label_to_indices[class_][
self.used_label_indices_count[class_]:self.used_label_indices_count[
class_] + self.n_samples])
self.used_label_indices_count[class_] += self.n_samples
if self.used_label_indices_count[class_] + self.n_samples < len(self.label_to_indices[class_]):
np.random.shuffle(self.label_to_indices[class_])
self.used_label_indices_count[class_] = 0
yield indices
self.count += self.n_classes * self.n_samples

def __len__(self):
return len(self.dataset) // self.batch_size
Binary file added images/anchor_negative_positive.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_ocl_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_ocl_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_otl_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_otl_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_siamese_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_siamese_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_softmax_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_softmax_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_triplet_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/fmnist_triplet_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_ocl_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_ocl_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_otl_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_otl_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_siamese_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_siamese_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_softmax_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_softmax_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_triplet_test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/mnist_triplet_train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 4ed17ef

Please sign in to comment.