A fast implementation of the histogram loss in PyTorch. The original paper is:
- Learning Deep Embeddings with Histogram Loss - Evgeniya Ustinova, Victor Lempitsky
Both the forward and backward functions are implemented, so it can be used as a loss function in your own work. It has proved stable on both CPUs and GPUs: no errors occurred during testing.
This implementation builds on two pieces of PyTorch documentation available online:
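For reference, the loss described in the paper can be written out in a few lines of plain Python. This is a sketch under stated assumptions, not this repository's code: similarities are taken to lie in [-1, 1] (e.g. cosine similarities), the `n_bins` histogram nodes are spaced uniformly as t_r = -1 + r·Δ with Δ = 2/(n_bins - 1), each value is split linearly between its two nearest nodes, and the loss is the sum over bins of the negative-pair histogram times the cumulative positive-pair histogram. The function names `soft_histogram` and `histogram_loss` are hypothetical, and the `w_pos`/`w_neg` arguments of the real function are not modelled.

```python
import math


def soft_histogram(values, n_bins):
    """Linearly interpolated histogram on nodes t_r = -1 + r * delta, r = 0..n_bins-1.

    Assumption: values are similarities in [-1, 1]; anything outside is clamped.
    """
    delta = 2.0 / (n_bins - 1)                      # spacing between histogram nodes
    hist = [0.0] * n_bins
    for s in values:
        pos = (min(max(s, -1.0), 1.0) + 1.0) / delta    # continuous bin coordinate
        lower = min(math.floor(pos), n_bins - 2)        # left node of the enclosing interval
        frac = pos - lower                              # fraction of the way to the right node
        hist[lower] += 1.0 - frac
        hist[lower + 1] += frac
    return [h / len(values) for h in hist]              # normalise so the histogram sums to 1


def histogram_loss(sim_pos, sim_neg, n_bins):
    """Estimated probability that a negative pair is more similar than a positive pair."""
    h_pos = soft_histogram(sim_pos, n_bins)
    h_neg = soft_histogram(sim_neg, n_bins)
    loss, cdf_pos = 0.0, 0.0
    for hp, hn in zip(h_pos, h_neg):
        cdf_pos += hp          # cumulative positive histogram up to this bin
        loss += hn * cdf_pos   # negative mass weighted by positive mass at or below it
    return loss
```

Well-separated pairs give a loss of 0 (e.g. `histogram_loss([0.9], [-0.9], 11)` returns `0.0`), while fully inverted pairs give a loss close to 1.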
- torch.bincount - the very fast bincount function in PyTorch
- Extending PyTorch - writing your own customised layer with both forward and backward functions
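The two ingredients listed above can be combined in a custom `torch.autograd.Function` whose forward pass builds the interpolated histogram with two weighted `torch.bincount` calls and whose backward pass returns the gradient in closed form. This is an illustrative sketch, not the code in `hist_loss`: the class name `SoftHistogram` is hypothetical, and it assumes a 1-D tensor of similarities in [-1, 1] (values outside are clamped and receive zero gradient).

```python
import torch


class SoftHistogram(torch.autograd.Function):
    """Hypothetical example: linearly interpolated histogram with a hand-written backward."""

    @staticmethod
    def forward(ctx, s, n_bins):
        # Assumes s is 1-D (torch.bincount only accepts 1-D input) with values in [-1, 1].
        n = s.numel()
        delta = 2.0 / (n_bins - 1)                    # node spacing: t_r = -1 + r * delta
        inside = (s >= -1.0) & (s <= 1.0)             # clamped values get zero gradient
        pos = (s.clamp(-1.0, 1.0) + 1.0) / delta      # continuous bin coordinate
        lower = pos.floor().clamp(max=n_bins - 2).long()
        frac = pos - lower.to(s.dtype)                # share of each value sent to bin lower + 1
        hist = (torch.bincount(lower, weights=1.0 - frac, minlength=n_bins)
                + torch.bincount(lower + 1, weights=frac, minlength=n_bins))
        ctx.save_for_backward(lower, inside)
        ctx.scale = 1.0 / (delta * n)                 # d frac / ds = 1 / delta, then divide by n
        return hist.to(s.dtype) / n

    @staticmethod
    def backward(ctx, grad_hist):
        lower, inside = ctx.saved_tensors
        # Each value adds -scale to hist[lower] and +scale to hist[lower + 1] per unit of s.
        grad_s = (grad_hist[lower + 1] - grad_hist[lower]) * ctx.scale
        return grad_s * inside.to(grad_s.dtype), None  # no gradient for n_bins
```

Because each value only moves weight linearly between two neighbouring bins, its gradient is the difference of the upstream gradients at those bins scaled by 1/(Δ·N), so the backward pass needs no extra bincount.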
Requirements: PyTorch >= v0.4.1
Import the function into Python:
from hist_loss import HistogramLoss
Initialise an instance of the function:
func_loss = HistogramLoss()
Forward computation:
loss = func_loss(sim_pos, sim_neg, n_bins, w_pos, w_neg)
Backward computation:
loss.backward()
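The forward call expects similarities of positive and negative pairs. One common way to obtain them, sketched here as an assumption rather than something this repository prescribes, is to L2-normalise a batch of embeddings, take all pairwise cosine similarities, and split the pairs (i < j) by whether their labels match. The helper name `pair_similarities` is hypothetical; the meaning of `w_pos` and `w_neg` is not documented here, so they are not covered.

```python
import torch
import torch.nn.functional as F


def pair_similarities(embeddings, labels):
    """Hypothetical helper: split pairwise cosine similarities into positive and negative pairs.

    embeddings: (N, D) float tensor; labels: (N,) integer tensor of class ids.
    """
    x = F.normalize(embeddings, dim=1)               # unit-length rows, so x @ x.T is cosine similarity
    sim = x @ x.t()
    n = labels.numel()
    i, j = torch.triu_indices(n, n, offset=1)        # every unordered pair i < j exactly once
    same = labels[i] == labels[j]                    # positive pair if the labels match
    pair_sim = sim[i, j]
    return pair_sim[same], pair_sim[~same]
```

For a batch labelled `[0, 0, 1, 1]` this yields 2 positive and 4 negative similarities, which could then be passed as `sim_pos` and `sim_neg` in the forward call.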
- [email protected] - Email
- @Shuai93Tang - Twitter
- Shuai Tang - Homepage