def ratio2weight_1(targets, ratio):
    r'''Exponential per-class re-weighting for multi-label BCE.

    Math formula:
    ```
    w_j = y_{ij} * e^{1 - r_j} + (1 - y_{ij}) * e^{r_j}
    ```
    REF: https://arxiv.org/abs/2107.03576v2

    Args:
        targets: binary label tensor (presumably (batch, num_classes);
            elements may exceed 1 for the RAP dataloader — see below).
        ratio: per-class positive-sample ratio r_j, broadcastable to targets.

    Returns:
        Loss-weight tensor with the same shape as ``targets``.
    '''
    # targets selects e^{1-r} for positives and e^{r} for negatives in a
    # single expression (the two addends are mutually exclusive).
    pos_weights = targets * (1. - ratio)
    neg_weights = (1. - targets) * ratio
    weights = paddle.exp(neg_weights + pos_weights)

    # For the RAP dataloader, target elements may be 2 (with or without
    # label smoothing some elements can be greater than 1); zero them out.
    weights = weights - weights * (targets > 1)

    return weights


def ratio2weight_2(targets, ratio):
    r'''Square-root per-class re-weighting for multi-label BCE.

    Math formula:
    ```
    w_j = y_{ij} * \sqrt{\frac{1}{2 * r_j}} + (1 - y_{ij}) * \sqrt{\frac{1}{2 * (1 - r_j)}}
    ```
    REF: https://arxiv.org/abs/2107.03576v2

    Args:
        targets: binary label tensor (elements may exceed 1 for RAP data).
        ratio: per-class positive-sample ratio r_j, broadcastable to targets.

    Returns:
        Loss-weight tensor with the same shape as ``targets``.
    '''
    # pos + neg equals r_j on positive entries and (1 - r_j) on negative
    # ones, so one sqrt/reciprocal realizes both branches of the formula.
    pos_weights = targets * ratio
    neg_weights = (1. - targets) * (1 - ratio)
    weights = paddle.sqrt(0.5 * paddle.reciprocal(neg_weights + pos_weights))

    # For the RAP dataloader, target elements may be 2 (with or without
    # label smoothing some elements can be greater than 1); zero them out.
    weights = weights - weights * (targets > 1)

    return weights


def ratio2weight_3(targets, ratio, alpha):
    r'''Normalized power-law per-class re-weighting for multi-label BCE.

    Math formula:
    ```
    w_j = y_{ij} * \frac{(1/r_j)^\alpha}{(1/r_j)^\alpha + (1/(1 - r_j))^\alpha}
        + (1 - y_{ij}) * \frac{(1/(1 - r_j))^\alpha}{(1/r_j)^\alpha + (1/(1 - r_j))^\alpha}
    ```
    REF: https://arxiv.org/abs/2107.03576v2

    Args:
        targets: binary label tensor (elements may exceed 1 for RAP data).
        ratio: per-class positive-sample ratio r_j, broadcastable to targets.
        alpha: exponent controlling how sharply rare classes are up-weighted.

    Returns:
        Loss-weight tensor with the same shape as ``targets``.
    '''
    # combined_weights equals r_j on positive entries and (1 - r_j) on
    # negative entries, so the single normalized expression below covers
    # both branches of the formula.
    pos_weights = targets * ratio
    neg_weights = (1. - targets) * (1 - ratio)
    combined_weights = pos_weights + neg_weights
    weights = paddle.divide(
        paddle.reciprocal(paddle.pow(combined_weights, alpha)),
        paddle.reciprocal(paddle.pow(combined_weights, alpha)) +
        paddle.reciprocal((paddle.pow(1 - combined_weights, alpha)))
    )

    # For the RAP dataloader, target elements may be 2 (with or without
    # label smoothing some elements can be greater than 1); zero them out.
    weights = weights - weights * (targets > 1)

    return weights
self.weight_alpha) + else: + weight = ratio2weight_1( + targets_mask, paddle.to_tensor(label_ratio)) weight = weight * (target > -1) cost = cost * weight From 2020168df34e36c61addd55ac0e9cc2722aefc71 Mon Sep 17 00:00:00 2001 From: thsno02 Date: Wed, 13 Sep 2023 14:01:41 +0800 Subject: [PATCH 2/4] feat: add 2 weight functions in multi-label --- ppcls/loss/dmlloss.py | 16 ++++++++-- ppcls/loss/multilabelloss.py | 57 ++++++++++++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/ppcls/loss/dmlloss.py b/ppcls/loss/dmlloss.py index 75d3756343..820860a137 100644 --- a/ppcls/loss/dmlloss.py +++ b/ppcls/loss/dmlloss.py @@ -16,7 +16,7 @@ import paddle.nn as nn import paddle.nn.functional as F -from ppcls.loss.multilabelloss import ratio2weight_1 +from ppcls.loss.multilabelloss import ratio2weight_1, ratio2weight_2, ratio2weight_3 class DMLLoss(nn.Layer): @@ -24,7 +24,7 @@ class DMLLoss(nn.Layer): DMLLoss """ - def __init__(self, act="softmax", sum_across_class_dim=False, eps=1e-12): + def __init__(self, act="softmax", sum_across_class_dim=False, eps=1e-12, weight_type=1, weight_alpha=0.1): super().__init__() if act is not None: assert act in ["softmax", "sigmoid"] @@ -36,6 +36,8 @@ def __init__(self, act="softmax", sum_across_class_dim=False, eps=1e-12): self.act = None self.eps = eps self.sum_across_class_dim = sum_across_class_dim + self.weight_type = weight_type + self.weight_alpha = weight_alpha def _kldiv(self, x, target): class_num = x.shape[-1] @@ -54,7 +56,15 @@ def forward(self, x, target, gt_label=None): if gt_label is not None: gt_label, label_ratio = gt_label[:, 0, :], gt_label[:, 1, :] targets_mask = paddle.cast(gt_label > 0.5, 'float32') - weight = ratio2weight(targets_mask, paddle.to_tensor(label_ratio)) + if self.weight_type == 2: + weight = ratio2weight_2( + targets_mask, paddle.to_tensor(label_ratio)) + elif self.weight_type == 3: + weight = ratio2weight_3( + targets_mask, paddle.to_tensor(label_ratio), 
class MultiLabelAsymmetricLoss(nn.Layer):
    """Multi-label asymmetric loss (ASL).

    Introduced by Emanuel Ben-Baruch et al. in
    https://arxiv.org/pdf/2009.14119v4.pdf.

    Args:
        gamma_pos: focusing exponent applied to positive targets.
        gamma_neg: focusing exponent applied to negative targets.
        clip: probability margin added to negative predictions before
            clipping at 1 (asymmetric clipping); 0/None disables it.
        epsilon: lower clamp for the log() argument, for numerical stability.
        disable_focal_loss_grad: if True, the focal weight is computed
            without gradient tracking (treated as a constant).
        reduction: one of "mean", "sum", "none".
    """

    def __init__(self,
                 gamma_pos=1,
                 gamma_neg=4,
                 clip=0.05,
                 epsilon=1e-8,
                 disable_focal_loss_grad=True,
                 reduction="sum"):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.epsilon = epsilon
        self.disable_focal_loss_grad = disable_focal_loss_grad
        assert reduction in ["mean", "sum", "none"]
        self.reduction = reduction

    def forward(self, x, target):
        """Compute the asymmetric loss.

        Args:
            x: raw logits, or a dict holding them under the "logits" key.
            target: multi-hot label tensor, broadcastable to the logits.

        Returns:
            Dict with key "MultiLabelAsymmetricLoss" mapping to the loss.
        """
        if isinstance(x, dict):
            x = x["logits"]
        pred_sigmoid = F.sigmoid(x)
        target = target.astype(pred_sigmoid.dtype)

        # Asymmetric clipping and basic CE probability term:
        # pt is p on positive entries and (1 - p + clip, capped at 1) on
        # negative entries.
        if self.clip and self.clip > 0:
            pt = (1 - pred_sigmoid + self.clip).clip(max=1) \
                * (1 - target) + pred_sigmoid * target
        else:
            pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target

        # Asymmetric focusing. A no_grad context is used instead of toggling
        # paddle.set_grad_enabled(False)/(True): the original toggle would
        # unconditionally re-enable gradients, clobbering the caller's global
        # grad mode if it was already disabled (e.g. during evaluation).
        focusing = self.gamma_pos * target + self.gamma_neg * (1 - target)
        if self.disable_focal_loss_grad:
            with paddle.no_grad():
                asymmetric_weight = (1 - pt).pow(focusing)
        else:
            asymmetric_weight = (1 - pt).pow(focusing)

        loss = -paddle.log(pt.clip(min=self.epsilon)) * asymmetric_weight

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return {"MultiLabelAsymmetricLoss": loss}
+24,7 @@ class DMLLoss(nn.Layer): DMLLoss """ - def __init__(self, act="softmax", sum_across_class_dim=False, eps=1e-12): + def __init__(self, act="softmax", sum_across_class_dim=False, eps=1e-12, weight_type=1, weight_alpha=0.1): super().__init__() if act is not None: assert act in ["softmax", "sigmoid"] @@ -36,6 +36,8 @@ def __init__(self, act="softmax", sum_across_class_dim=False, eps=1e-12): self.act = None self.eps = eps self.sum_across_class_dim = sum_across_class_dim + self.weight_type = weight_type + self.weight_alpha = weight_alpha def _kldiv(self, x, target): class_num = x.shape[-1] @@ -54,7 +56,15 @@ def forward(self, x, target, gt_label=None): if gt_label is not None: gt_label, label_ratio = gt_label[:, 0, :], gt_label[:, 1, :] targets_mask = paddle.cast(gt_label > 0.5, 'float32') - weight = ratio2weight(targets_mask, paddle.to_tensor(label_ratio)) + if self.weight_type == 2: + weight = ratio2weight_2( + targets_mask, paddle.to_tensor(label_ratio)) + elif self.weight_type == 3: + weight = ratio2weight_3( + targets_mask, paddle.to_tensor(label_ratio), self.weight_alpha) + else: + weight = ratio2weight_1( + targets_mask, paddle.to_tensor(label_ratio)) weight = weight * (gt_label > -1) loss = loss * weight diff --git a/ppcls/loss/multilabelloss.py b/ppcls/loss/multilabelloss.py index ea37cdf344..c9a6a33685 100644 --- a/ppcls/loss/multilabelloss.py +++ b/ppcls/loss/multilabelloss.py @@ -11,7 +11,6 @@ def ratio2weight_1(targets, ratio): ``` REF: https://arxiv.org/abs/2107.03576v2 ''' - pos_weights = targets * (1. - ratio) neg_weights = (1. 
- targets) * ratio weights = paddle.exp(neg_weights + pos_weights) @@ -71,14 +70,14 @@ class MultiLabelLoss(nn.Layer): Multi-label loss """ - def __init__(self, epsilon=None, size_sum=False, weight_type=1, weight_ratio=False, weight_alpha=False): + def __init__(self, epsilon=None, size_sum=False, weight_ratio=False, weight_type=1, weight_alpha=0.1): super().__init__() if epsilon is not None and (epsilon <= 0 or epsilon >= 1): epsilon = None self.epsilon = epsilon + self.weight_ratio = weight_ratio self.weight_type = weight_type self.weight_alpha = weight_alpha - self.weight_ratio = weight_ratio self.size_sum = size_sum def _labelsmoothing(self, target, class_num): From 7f6c1172ea247e00a2520fecc5ec654525951bb2 Mon Sep 17 00:00:00 2001 From: thsno02 <103615460+thsno02@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:13:41 +0800 Subject: [PATCH 4/4] fix format --- ppcls/loss/multilabelloss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppcls/loss/multilabelloss.py b/ppcls/loss/multilabelloss.py index ee83a07795..963da96e9a 100644 --- a/ppcls/loss/multilabelloss.py +++ b/ppcls/loss/multilabelloss.py @@ -156,7 +156,7 @@ def forward(self, x, target): # Asymmetric Clipping and Basic CE calculation if self.clip and self.clip > 0: pt = (1 - pred_sigmoid + self.clip).clip(max=1) \ - * (1 - target) + pred_sigmoid * target + * (1 - target) + pred_sigmoid * target else: pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target