Skip to content

Commit

Permalink
The code on momentum in the original mi-fgsm is complex and lacks cor…
Browse files Browse the repository at this point in the history
…respondence with the pseudo-code in the paper
  • Loading branch information
rikonaka committed Apr 15, 2024
1 parent 53d35ec commit 58d55d4
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions torchattacks/attacks/mifgsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,13 @@ def forward(self, images, labels):
if self.targeted:
target_labels = self.get_target_label(images, labels)

momentum = torch.zeros_like(images).detach().to(self.device)

loss = nn.CrossEntropyLoss()

adv_images = images.clone().detach()
adv_images.requires_grad = True
momentum = torch.zeros_like(images).detach().to(self.device)

for _ in range(self.steps):
adv_images.requires_grad = True
outputs = self.get_logits(adv_images)

# Calculate loss
Expand All @@ -70,11 +69,11 @@ def forward(self, images, labels):
)[0]

grad = grad / torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True)
grad = grad + momentum * self.decay
momentum = grad
momentum = self.decay * momentum + grad

adv_images = adv_images.detach() + self.alpha * grad.sign()
adv_images = adv_images + self.alpha * torch.sign(momentum)
delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
adv_images = torch.clamp(images + delta, min=0, max=1).detach()
adv_images = images + delta

adv_images = torch.clamp(images + delta, min=0, max=1)
return adv_images

0 comments on commit 58d55d4

Please sign in to comment.