diff --git a/adversarial_attack/fgsm.py b/adversarial_attack/fgsm.py index 1504e87..0da9e2c 100644 --- a/adversarial_attack/fgsm.py +++ b/adversarial_attack/fgsm.py @@ -126,7 +126,7 @@ def targeted_attack( grad = compute_gradient(model=model, input=adv_tensor, target=target) adv_tensor = torch.clamp(adv_tensor - epsilon * grad.sign(), -2, 2) new_pred = model(adv_tensor).argmax() - if orig_pred_idx.item() != new_pred: + if orig_pred_idx.item() == target.item(): return adv_tensor, orig_pred, new_pred warnings.warn(