diff --git a/torchattacks/attacks/mifgsm.py b/torchattacks/attacks/mifgsm.py index fc65d191..18256149 100644 --- a/torchattacks/attacks/mifgsm.py +++ b/torchattacks/attacks/mifgsm.py @@ -48,14 +48,13 @@ def forward(self, images, labels): if self.targeted: target_labels = self.get_target_label(images, labels) - momentum = torch.zeros_like(images).detach().to(self.device) loss = nn.CrossEntropyLoss() - adv_images = images.clone().detach() + adv_images.requires_grad = True + momentum = torch.zeros_like(images).detach().to(self.device) for _ in range(self.steps): - adv_images.requires_grad = True outputs = self.get_logits(adv_images) # Calculate loss @@ -70,11 +69,11 @@ def forward(self, images, labels): )[0] grad = grad / torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True) - grad = grad + momentum * self.decay - momentum = grad + momentum = self.decay * momentum + grad - adv_images = adv_images.detach() + self.alpha * grad.sign() + adv_images = adv_images + self.alpha * torch.sign(momentum) delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps) - adv_images = torch.clamp(images + delta, min=0, max=1).detach() + adv_images = images + delta + adv_images = torch.clamp(images + delta, min=0, max=1) return adv_images