From 0e7e754ee5f8a84ddb6102518e71f1e0205e76dd Mon Sep 17 00:00:00 2001
From: John Schreck
Date: Wed, 10 Jul 2024 09:19:38 -0600
Subject: [PATCH] Restructure of original evidential loss + renaming

---
 mlguess/torch/regression_losses.py | 66 +++++++++++++++---------------
 1 file changed, 33 insertions(+), 33 deletions(-)

diff --git a/mlguess/torch/regression_losses.py b/mlguess/torch/regression_losses.py
index 481996b..bb52ef2 100644
--- a/mlguess/torch/regression_losses.py
+++ b/mlguess/torch/regression_losses.py
@@ -7,46 +7,46 @@
 tol = torch.finfo(torch.float32).eps
 
 
-def nig_nll(y, gamma, v, alpha, beta):
-    """Implements Normal Inverse Gamma-Negative Log Likelihood for
-    Deep Evidential Regression
+# original loss used in Amini paper
 
-    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
-    Source: https://github.com/hxu296/torch-evidental-deep-learning
-    """
-    two_blambda = 2 * beta * (1 + v) + tol
-    nll = 0.5 * torch.log(np.pi / (v + tol)) \
-        - alpha * torch.log(two_blambda + tol) \
-        + (alpha + 0.5) * torch.log(v * (y - gamma) ** 2 + two_blambda + tol) \
-        + torch.lgamma(alpha) \
-        - torch.lgamma(alpha + 0.5)
-
-    return nll
+class EvidentialRegressionLoss:
+    def __init__(self, coef=1.0):
+        self.coef = coef
 
-def nig_reg(y, gamma, v, alpha):
-    """Implements Normal Inverse Gamma Regularizer for Deep Evidential
-    Regression
+    def normal_inverse_gamma_nll(self, y, gamma, v, alpha, beta):
+        """Implements Normal Inverse Gamma-Negative Log Likelihood for
+        Deep Evidential Regression
 
-    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
-    Source: https://github.com/hxu296/torch-evidental-deep-learning
-    """
-    error = F.l1_loss(y, gamma, reduction="none")
-    evi = 2 * v + alpha
-    return error * evi
+        Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
+        Source: https://github.com/hxu296/torch-evidental-deep-learning
+        """
+        two_blambda = 2 * beta * (1 + v) + tol
+        nll = 0.5 * torch.log(np.pi / (v + tol)) \
+            - alpha * torch.log(two_blambda + tol) \
+            + (alpha + 0.5) * torch.log(v * (y - gamma) ** 2 + two_blambda + tol) \
+            + torch.lgamma(alpha) \
+            - torch.lgamma(alpha + 0.5)
+        return nll
 
-def evidential_regression_loss(y, pred, coef=1.0):
-    """Implements Evidential Regression Loss for Deep Evidential
-    Regression
+    def normal_inverse_gamma_reg(self, y, gamma, v, alpha, beta):
+        """Implements Normal Inverse Gamma Regularizer for Deep Evidential
+        Regression
 
-    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
-    Source: https://github.com/hxu296/torch-evidental-deep-learning
-    """
-    gamma, v, alpha, beta = pred
-    loss_nll = nig_nll(y, gamma, v, alpha, beta)
-    loss_reg = nig_reg(y, gamma, v, alpha, beta)
-    return loss_nll.mean() + coef * loss_reg.mean()
+        Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
+        Source: https://github.com/hxu296/torch-evidental-deep-learning
+        """
+        error = F.l1_loss(y, gamma, reduction="none")
+        evi = 2 * v + alpha
+        return error * evi
+
+    def __call__(self, y, pred):
+        """Calculate the Evidential Regression Loss"""
+        gamma, v, alpha, beta = pred
+        loss_nll = self.normal_inverse_gamma_nll(y, gamma, v, alpha, beta)
+        loss_reg = self.normal_inverse_gamma_reg(y, gamma, v, alpha, beta)
+        return loss_nll.mean() + self.coef * loss_reg.mean()
 
 
 # code below based off https://github.com/deargen/MT-ENet
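
Editor's note (not part of the patch): a minimal usage sketch of the restructured loss. The import path matches the file this diff touches; everything else is an illustrative assumption. In particular, the softplus transforms and tensor shapes are one common way an evidential network head might enforce the Normal Inverse Gamma constraints (v > 0, alpha > 1, beta > 0) and are not code from this repository.

    import torch
    import torch.nn.functional as F

    from mlguess.torch.regression_losses import EvidentialRegressionLoss

    # Hypothetical evidential head output for a batch of 8 scalar targets:
    # four raw values per example, transformed to satisfy the NIG constraints.
    raw = torch.randn(8, 4, requires_grad=True)
    gamma = raw[:, 0:1]                    # predicted mean, unconstrained
    v = F.softplus(raw[:, 1:2])            # v > 0
    alpha = F.softplus(raw[:, 2:3]) + 1.0  # alpha > 1
    beta = F.softplus(raw[:, 3:4])         # beta > 0
    y = torch.randn(8, 1)                  # regression targets

    loss_fn = EvidentialRegressionLoss(coef=1.0)
    loss = loss_fn(y, (gamma, v, alpha, beta))  # mean NLL + coef * mean regularizer
    loss.backward()  # scalar loss; gradients flow back to `raw`

One side effect of the restructure worth noting: normal_inverse_gamma_reg accepts a beta argument it never uses, which keeps its signature aligned with normal_inverse_gamma_nll and removes the latent bug in the deleted code, where nig_reg (a four-parameter function) was called with five arguments.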