
v0.9.92

Released by @KevinMusgrave on 14 Sep 09:53

New Features

DistributedLossWrapper and DistributedMinerWrapper

Added DistributedLossWrapper and DistributedMinerWrapper. Wrap a loss or miner with these when training with PyTorch's DistributedDataParallel (i.e. multi-process training). Most of the code is by @JohnGiorgi (https://github.com/JohnGiorgi/DeCLUTR).

from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist
loss_func = pml_dist.DistributedLossWrapper(loss = losses.ContrastiveLoss())
miner = pml_dist.DistributedMinerWrapper(miner = miners.MultiSimilarityMiner())

For a working example, see the "Multiprocessing with DistributedDataParallel" notebook.
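
For orientation, here is a minimal sketch (not taken from the notebook) of how the wrapped loss and miner might be called inside each DDP worker process; the model, optimizer, and data pipeline are assumed to be set up elsewhere:

# hypothetical per-process training step; assumes `model` is already wrapped
# in DistributedDataParallel and that `loss_func` and `miner` are the
# DistributedLossWrapper / DistributedMinerWrapper objects created above
def train_step(model, optimizer, data, labels, loss_func, miner):
    embeddings = model(data)
    # the wrappers gather embeddings and labels from all processes
    # before mining and computing the loss
    hard_pairs = miner(embeddings, labels)
    loss = loss_func(embeddings, labels, hard_pairs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss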

Added enqueue_idx to CrossBatchMemory

Now you can make CrossBatchMemory work with MoCo. This adds a great deal of flexibility to the MoCo framework, because you can use any tuple loss and tuple miner in CrossBatchMemory.

Previously this wasn't possible because all embeddings passed into CrossBatchMemory would go into the memory queue. In contrast, MoCo only queues the momentum encoder's embeddings.

The new enqueue_idx argument lets you do this by specifying which embeddings should be added to memory. Here's a modified snippet from the MoCo on CIFAR10 notebook:

from pytorch_metric_learning.losses import CrossBatchMemory, NTXentLoss

loss_fn = CrossBatchMemory(loss = NTXentLoss(), embedding_size = 64, memory_size = 16384)

### snippet from the training loop ###
for images, _ in train_loader:
  ...
  previous_max_label = torch.max(loss_fn.label_memory)
  num_pos_pairs = encQ_out.size(0)
  labels = torch.arange(0, num_pos_pairs)
  labels = torch.cat((labels, labels)).to(device)

  ### add an offset so that the labels do not overlap with any labels in the memory queue ###
  labels += previous_max_label + 1

  ### we want to enqueue the output of encK, which is the 2nd half of the batch ###
  enqueue_idx = torch.arange(num_pos_pairs, num_pos_pairs*2)

  all_enc = torch.cat([encQ_out, encK_out], dim=0)

  ### now only encK_out will be added to the memory queue ###
  loss = loss_fn(all_enc, labels, enqueue_idx = enqueue_idx)
  ...

Check out the MoCo on CIFAR10 notebook to see the entire script.

TuplesToWeightsSampler

This is a simple offline miner. It does the following:

  1. Take a random subset of your dataset, if you provide subset_size.
  2. Use the specified miner to mine tuples from that subset.
  3. Compute weights based on how often an element appears in the mined tuples.
  4. Randomly sample, using the weights as probabilities.

from pytorch_metric_learning.samplers import TuplesToWeightsSampler
from pytorch_metric_learning.miners import MultiSimilarityMiner

miner = MultiSimilarityMiner(epsilon=-0.2)
sampler = TuplesToWeightsSampler(model, miner, dataset, subset_size = 5000)
# then pass the sampler into your DataLoader
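
For example, the "pass the sampler into your DataLoader" step might look like this (a sketch; the batch size is arbitrary):

from torch.utils.data import DataLoader

# the sampler yields dataset indices, weighted by how often the miner picked them
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)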

LogitGetter

Added utils.inference.LogitGetter to make it easier to compute logits of classifier loss functions.

from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.utils.inference import LogitGetter

loss_fn = ArcFaceLoss(num_classes = 100, embedding_size = 512)
LG = LogitGetter(loss_fn)
logits = LG(embeddings)
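
The logits have one column per class, so they can be used like an ordinary classifier output. For example:

import torch

# hypothetical follow-up: predicted class per embedding
predictions = torch.argmax(logits, dim=1)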

Other

  • Added an optional batch_size argument to MPerClassSampler. If you pass in this argument, then each batch is guaranteed to have m samples per class. Otherwise, most batches will have m samples per class, but it's not guaranteed for every batch. Note that there are restrictions on the values of m and batch_size; for example, batch_size must be a multiple of m. For all the restrictions, see the documentation. A short example is shown after this list.

  • Added trainable_attributes to BaseTrainer, to standardize the set_to_train and set_to_eval functions.

  • Added a save_models init argument to HookContainer. If set to False, models will not be saved.

  • Added losses_sizes as a stat for BaseReducer.

  • Added a type check and conversion in common_functions.labels_to_indices to go from a torch tensor to a numpy array.
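
As an illustration of the new batch_size argument for MPerClassSampler, here is a minimal sketch (labels is assumed to hold the label of every sample in dataset, and 32 is a multiple of m=4):

from torch.utils.data import DataLoader
from pytorch_metric_learning.samplers import MPerClassSampler

sampler = MPerClassSampler(labels, m=4, batch_size=32)
# the DataLoader's batch_size should match the sampler's batch_size
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)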