Merge branch 'master' into param-e2e
tomcarter23 authored Dec 4, 2024
2 parents f64deda + 089eb85 commit 755a987
Showing 3 changed files with 51 additions and 13 deletions.
23 changes: 20 additions & 3 deletions adversarial_attack/__main__.py
@@ -2,6 +2,7 @@
import sys
from PIL import Image
import numpy as np
import logging

from adversarial_attack.resnet_utils import (
AVAILABLE_MODELS,
@@ -11,9 +12,18 @@
to_array,
preprocess_image,
)
from .fgsm import get_attack_fn
from adversarial_attack.api import perform_attack

logger = logging.getLogger("adversarial_attack")


def setup_logging(level: str):
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s:%(name)s:%(levelname)s\t%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(level)


def main():
"""
@@ -82,9 +92,16 @@ def main():
help="Path to save the adversarial image.",
required=False,
)
parser.add_argument(
"--log",
default="WARNING",
help=f"Set the logging level. Available options: {list(logging._nameToLevel.keys())}",
)

args = parser.parse_args()

setup_logging(level=args.log)

if args.mode == "targeted" and args.category_target is None:
raise ValueError("Target category is required for targeted attacks.")

@@ -99,8 +116,8 @@ def main():
image=image_tensor,
categories=get_model_categories(args.model),
true_category=args.category_truth,
- epsilon=args.epsilon,
- max_iter=args.max_iterations,
+ epsilon=float(args.epsilon),
+ max_iter=int(args.max_iterations),
target_category=args.category_target,
)

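A note on the two call-site edits in this file (the new --log flag and the float()/int() casts): argparse returns option values as plain strings unless a type= is supplied, and logging.Logger.setLevel accepts a level name string such as "DEBUG". A minimal sketch of that behaviour, with hypothetical defaults and flag spellings for everything the diff does not show:

```python
import argparse
import logging

# hypothetical parser; only --log and its default are confirmed by the diff above
parser = argparse.ArgumentParser()
parser.add_argument("--epsilon", default="0.01")
parser.add_argument("--max-iterations", default="50")
parser.add_argument("--log", default="WARNING")

args = parser.parse_args(["--epsilon", "0.05", "--log", "DEBUG"])

epsilon = float(args.epsilon)        # argparse parsed "0.05" as a str, hence the explicit casts
max_iter = int(args.max_iterations)

logger = logging.getLogger("adversarial_attack")   # same named logger fgsm.py retrieves
logger.addHandler(logging.StreamHandler())
logger.setLevel(args.log)                          # level names are accepted directly
logger.debug("attack iteration messages are now visible at DEBUG")
```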
35 changes: 29 additions & 6 deletions adversarial_attack/fgsm.py
@@ -1,9 +1,12 @@
import typing as ty
import torch
import warnings
import logging

from .resnet_utils import category_to_tensor

logger = logging.getLogger("adversarial_attack")


def compute_gradient(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
"""
@@ -52,8 +55,13 @@ def standard_attack(
Returns:
torch.Tensor: Adversarial image tensor or None if attack failed.
"""
logger.info("Conducting standard attack")

with torch.no_grad():
orig_pred = model(tensor)
logger.debug(
f"Original prediction class: {orig_pred.argmax()}, probability: {torch.nn.functional.softmax(orig_pred, dim=1).max()}"
)

if orig_pred.argmax().item() != truth.item():
warnings.warn(
@@ -70,12 +78,18 @@ def standard_attack(
orig_pred_idx = torch.tensor([orig_pred.argmax().item()])

for i in range(max_iter):
logger.debug(f"Current output: {model(adv_tensor)}")
model.zero_grad()
grad = compute_gradient(model=model, input=adv_tensor, target=orig_pred_idx)
adv_tensor = torch.clamp(adv_tensor + epsilon * grad.sign(), -2, 2)
- new_pred = model(adv_tensor).argmax()
+ new_output = model(adv_tensor)
+ new_pred = new_output.argmax()
logger.debug(
f"attack iteration {i}, current prediction: {new_pred.item()}, current max probability: {torch.nn.functional.softmax(new_output, dim=1).max()}"
)
if orig_pred_idx.item() != new_pred:
- return adv_tensor, orig_pred, new_pred
+ logger.info(f"Standard attack successful.")
+ return adv_tensor, orig_pred.argmax(), new_pred

warnings.warn(
f"Failed to alter the prediction of the model after {max_iter} tries.",
@@ -108,8 +122,13 @@ def targeted_attack(
Returns:
torch.Tensor: Adversarial image tensor or None if attack failed.
"""
logger.info("Conducting targeted attack")

with torch.no_grad():
orig_pred = model(tensor)
logger.debug(
f"Original prediction class: {orig_pred.argmax()}, probability: {torch.nn.functional.softmax(orig_pred, dim=1).max()}"
)

if orig_pred.argmax().item() != truth.item():
raise ValueError(
@@ -119,15 +138,19 @@

# make a copy of the input tensor
adv_tensor = tensor.clone().detach()
- orig_pred_idx = torch.tensor([orig_pred.argmax().item()])

for i in range(max_iter):
model.zero_grad()
grad = compute_gradient(model=model, input=adv_tensor, target=target)
adv_tensor = torch.clamp(adv_tensor - epsilon * grad.sign(), -2, 2)
- new_pred = model(adv_tensor).argmax()
- if orig_pred_idx.item() == target.item():
- return adv_tensor, orig_pred, new_pred
+ new_output = model(adv_tensor)
+ new_pred = new_output.argmax()
+ logger.debug(
+ f"Attack iteration {i}, target: {target.item()}, current prediction: {new_pred.item()}, current max probability: {torch.nn.functional.softmax(new_output, dim=1).max()}"
+ )
+ if new_pred.item() == target.item():
+ logger.info(f"Targeted attack successful.")
+ return adv_tensor, orig_pred.argmax(), new_pred

warnings.warn(
f"Failed to alter the prediction of the model after {max_iter} tries.",
6 changes: 2 additions & 4 deletions tests/unit/test_fgsm.py
@@ -57,7 +57,7 @@ def test_standard_attack_success(model):

assert result is not None, "Attack should succeed."
adv_tensor, orig_pred, new_pred = result
- assert orig_pred.argmax() != new_pred, "Attack should change the model prediction."
+ assert new_pred.item() != orig_pred.item(), "Attack should change the model prediction."


def test_targeted_attack_success(model):
@@ -80,9 +80,7 @@ def test_targeted_attack_success(model):

assert result is not None, "Attack should succeed."
adv_tensor, orig_pred, new_pred = result
- assert (
- orig_pred.argmax() != 2
- ), "Attack should change the model prediction to target."
+ assert new_pred.item() == 2, "Attack should change the model prediction to target."


def test_standard_attack_failure(model):
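The assertion rewrites above track the new return convention in fgsm.py: both attacks now return orig_pred.argmax() (a scalar class-index tensor) alongside new_pred, so the tests compare class indices with .item() instead of comparing a logits tensor against an index. A hedged sketch of that calling convention as it would sit inside one of the test bodies above (keyword names are inferred from the function bodies shown earlier; model, tensor and truth stand in for the fixtures the full tests build):

```python
result = standard_attack(model=model, tensor=tensor, truth=truth, epsilon=0.05, max_iter=10)

assert result is not None, "Attack should succeed."
adv_tensor, orig_pred, new_pred = result       # both predictions are scalar class-index tensors
assert new_pred.item() != orig_pred.item()     # untargeted: any change of predicted class passes
```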
