
A fast implementation of the histogram loss in pytorch


desa-lab/HistLoss


Histogram Loss

A fast implementation of the histogram loss in PyTorch. The original paper can be found here:

Getting started

Both the forward and backward functions are implemented, so the loss can be used directly in your own training code. It has been tested on both CPUs and GPUs, and no notable errors occurred during testing.
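For reference, the quantity being computed can be sketched in plain PyTorch autograd. This sketch follows the original paper's formulation as we understand it and is not this repository's code: similarities are assumed to lie in [-1, 1], each similarity is spread over the two nearest of `n_bins` evenly spaced nodes by linear interpolation, and the loss is the sum over bins of the negative histogram times the cumulative positive histogram, i.e. an estimate of the probability that a random negative pair is more similar than a random positive pair. `soft_histogram` and `histogram_loss` are hypothetical names, and the `w_pos`/`w_neg` arguments of `HistogramLoss` are not modelled.

```python
import torch


def soft_histogram(sims: torch.Tensor, n_bins: int) -> torch.Tensor:
    """Linearly interpolated histogram of similarities assumed to lie in [-1, 1].

    Each value splits its unit mass between its two nearest bin nodes,
    so the result is differentiable with respect to `sims`.
    """
    step = 2.0 / (n_bins - 1)  # spacing between bin nodes t_r
    nodes = torch.linspace(-1.0, 1.0, n_bins, dtype=sims.dtype, device=sims.device)
    # Triangular kernel: weight(s, t_r) = max(0, 1 - |s - t_r| / step), shape (N, n_bins)
    weights = torch.clamp(1.0 - (sims.unsqueeze(1) - nodes).abs() / step, min=0.0)
    return weights.sum(dim=0) / sims.numel()  # normalise to a probability mass


def histogram_loss(sim_pos: torch.Tensor, sim_neg: torch.Tensor, n_bins: int = 51) -> torch.Tensor:
    """L = sum_r h_neg[r] * Phi_pos[r], where Phi_pos[r] = sum_{q <= r} h_pos[q]."""
    h_pos = soft_histogram(sim_pos, n_bins)
    h_neg = soft_histogram(sim_neg, n_bins)
    cdf_pos = torch.cumsum(h_pos, dim=0)  # cumulative positive histogram
    return torch.sum(h_neg * cdf_pos)
```

Here autograd derives the backward pass automatically, at the cost of an `N x n_bins` weight matrix per call; a bincount-based forward with a hand-written backward, as this repository describes, avoids materialising that matrix.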

Implementation

This implementation builds on two PyTorch resources:

  • torch.bincount - the very fast bincount function in PyTorch
  • Extending PyTorch - writing your own custom layer with both forward and backward functions
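The two ingredients above can be combined in a custom `torch.autograd.Function`. The hypothetical `LinearHistogram` below is a sketch, not code from this repository: it builds a linearly interpolated histogram of values assumed to lie in [-1, 1] with two `torch.bincount` calls in `forward`, and supplies the gradient by hand in `backward`. Moving a value by δ shifts δ/step of its mass from its left node to its right node, which gives the gradient formula in the comments. It uses static `forward`/`backward` methods, the style expected by current PyTorch releases.

```python
import torch


class LinearHistogram(torch.autograd.Function):
    """Hypothetical bincount-based histogram with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, sims, n_bins):
        step = 2.0 / (n_bins - 1)
        coord = (sims + 1.0) / step  # continuous bin coordinate in [0, n_bins - 1]
        lower = coord.floor().clamp(0, n_bins - 2).long()  # index of the left node
        frac = coord - lower.to(sims.dtype)  # fractional distance to the left node
        # Each value gives (1 - frac) to its left node and frac to its right node.
        hist = torch.bincount(lower, weights=1.0 - frac, minlength=n_bins)
        hist = hist + torch.bincount(lower + 1, weights=frac, minlength=n_bins)
        ctx.save_for_backward(lower)
        ctx.step = step
        ctx.n = sims.numel()
        return hist / sims.numel()

    @staticmethod
    def backward(ctx, grad_hist):
        (lower,) = ctx.saved_tensors
        # d hist[lower] / ds = -1 / (step * n) and d hist[lower + 1] / ds = +1 / (step * n)
        grad = (grad_hist[lower + 1] - grad_hist[lower]) / (ctx.step * ctx.n)
        return grad, None  # no gradient for the integer n_bins
```

A hand-written backward is easy to get wrong, so checking it against finite differences on float64 inputs with `torch.autograd.gradcheck(LinearHistogram.apply, (x, n_bins))` is a cheap safeguard.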

Prerequisites

pytorch >= v0.4.1

Running

Import the function into Python

from hist_loss import HistogramLoss

Initialise an instance of the function

func_loss = HistogramLoss()

Forward computation

loss = func_loss(sim_pos, sim_neg, n_bins, w_pos, w_neg)

Backward computation

loss.backward()
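The steps above take `sim_pos` and `sim_neg` as given. In the original paper these are, as we understand it, cosine similarities between L2-normalised embeddings of positive pairs (same class) and negative pairs (different classes) within a batch. The hypothetical helper `pair_similarities` below sketches one way to produce them; it is an assumption for illustration, and the meaning of `w_pos` and `w_neg` is not documented here, so they are left out.

```python
import torch
import torch.nn.functional as F


def pair_similarities(embeddings: torch.Tensor, labels: torch.Tensor):
    """Hypothetical helper: split pairwise cosine similarities into positive and negative sets.

    Positive pairs share a label, negative pairs do not; each unordered pair is used once.
    """
    emb = F.normalize(embeddings, dim=1)  # unit length, so dot product = cosine similarity
    sim = (emb @ emb.t()).clamp(-1.0, 1.0)  # guard against rounding just outside [-1, 1]
    n = labels.numel()
    # Indices of the strict upper triangle: every pair (i, j) with i < j
    i, j = torch.triu_indices(n, n, offset=1, device=embeddings.device)
    sims = sim[i, j]
    same = labels[i] == labels[j]
    return sims[same], sims[~same]
```

The returned tensors keep their autograd history, so a loss built from them backpropagates into the network that produced `embeddings`.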

Contact
