Skip to content

Commit

Permalink
Restructure of original evidential loss + renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
jsschreck committed Jul 10, 2024
1 parent 41b008b commit 0e7e754
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions mlguess/torch/regression_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,46 @@
tol = torch.finfo(torch.float32).eps


def nig_nll(y, gamma, v, alpha, beta):
"""Implements Normal Inverse Gamma-Negative Log Likelihood for
Deep Evidential Regression
# original loss used in Amini paper

Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
Source: https://github.com/hxu296/torch-evidental-deep-learning
"""
two_blambda = 2 * beta * (1 + v) + tol
nll = 0.5 * torch.log(np.pi / (v + tol)) \
- alpha * torch.log(two_blambda + tol) \
+ (alpha + 0.5) * torch.log(v * (y - gamma) ** 2 + two_blambda + tol) \
+ torch.lgamma(alpha) \
- torch.lgamma(alpha + 0.5)

return nll

class EvidentialRegressionLoss:
def __init__(self, coef=1.0):
self.coef = coef

def nig_reg(y, gamma, v, alpha):
"""Implements Normal Inverse Gamma Regularizer for Deep Evidential
Regression
def normal_inverse_gamma_nll(self, y, gamma, v, alpha, beta):
"""Implements Normal Inverse Gamma-Negative Log Likelihood for
Deep Evidential Regression
Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
Source: https://github.com/hxu296/torch-evidental-deep-learning
"""
error = F.l1_loss(y, gamma, reduction="none")
evi = 2 * v + alpha
return error * evi
Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
Source: https://github.com/hxu296/torch-evidental-deep-learning
"""
two_blambda = 2 * beta * (1 + v) + tol
nll = 0.5 * torch.log(np.pi / (v + tol)) \
- alpha * torch.log(two_blambda + tol) \
+ (alpha + 0.5) * torch.log(v * (y - gamma) ** 2 + two_blambda + tol) \
+ torch.lgamma(alpha) \
- torch.lgamma(alpha + 0.5)

return nll

def evidential_regression_loss(y, pred, coef=1.0):
"""Implements Evidential Regression Loss for Deep Evidential
Regression
def normal_inverse_gamma_reg(self, y, gamma, v, alpha, beta):
"""Implements Normal Inverse Gamma Regularizer for Deep Evidential
Regression
Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
Source: https://github.com/hxu296/torch-evidental-deep-learning
"""
gamma, v, alpha, beta = pred
loss_nll = nig_nll(y, gamma, v, alpha, beta)
loss_reg = nig_reg(y, gamma, v, alpha, beta)
return loss_nll.mean() + coef * loss_reg.mean()
Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
Source: https://github.com/hxu296/torch-evidental-deep-learning
"""
error = F.l1_loss(y, gamma, reduction="none")
evi = 2 * v + alpha
return error * evi

def __call__(self, y, pred):
"""Calculate the Evidential Regression Loss"""
gamma, v, alpha, beta = pred
loss_nll = self.normal_inverse_gamma_nll(y, gamma, v, alpha, beta)
loss_reg = self.normal_inverse_gamma_reg(y, gamma, v, alpha, beta)
return loss_nll.mean() + self.coef * loss_reg.mean()


# code below based off https://github.com/deargen/MT-ENet
Expand Down

0 comments on commit 0e7e754

Please sign in to comment.