Skip to content

Commit

Permalink
Merge branch 'master' into minor-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tomcarter23 authored Dec 4, 2024
2 parents 936319b + b8e9072 commit e41dfbf
Show file tree
Hide file tree
Showing 16 changed files with 308 additions and 65 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/e2e_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .[test]
- name: Run Tests
- name: Run Standard FGSM Tests
run: |
pytest ./tests/e2e/
pytest ./tests/e2e/test_end2end.py::test_perform_attack_standard -v --log-cli-level=INFO
- name: Run Targeted FGSM Tests
run: |
pytest ./tests/e2e/test_end2end.py::test_perform_attack_targeted -v --log-cli-level=INFO
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ jobs:
pip install -e .[test]
- name: Run Tests
run: |
pytest ./tests/unit/
pytest ./tests/unit/ -v
19 changes: 18 additions & 1 deletion adversarial_attack/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from PIL import Image
import numpy as np
import logging

from adversarial_attack.resnet_utils import (
AVAILABLE_MODELS,
Expand All @@ -11,9 +12,18 @@
to_array,
preprocess_image,
)
from .fgsm import get_attack_fn
from adversarial_attack.api import perform_attack

logger = logging.getLogger("adversarial_attack")


def setup_logging(level: str):
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s:%(name)s:%(levelname)s\t%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(level)


def main():
"""
Expand Down Expand Up @@ -82,9 +92,16 @@ def main():
help="Path to save the adversarial image.",
required=False,
)
parser.add_argument(
"--log",
default="WARNING",
help=f"Set the logging level. Available options: {list(logging._nameToLevel.keys())}",
)

args = parser.parse_args()

setup_logging(level=args.log)

if args.mode == "targeted" and args.category_target is None:
raise ValueError("Target category is required for targeted attacks.")

Expand Down
20 changes: 20 additions & 0 deletions adversarial_attack/fgsm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import typing as ty
import torch
import warnings
import logging

from .resnet_utils import category_to_tensor

logger = logging.getLogger("adversarial_attack")


def compute_gradient(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
"""
Expand Down Expand Up @@ -34,6 +37,7 @@ def standard_attack(
model: torch.nn.Module,
tensor: torch.Tensor,
truth: torch.Tensor,
categories: list[str],
epsilon: float = 1e-3,
max_iter: int = 50,
) -> ty.Optional[tuple[torch.Tensor, int, int]]:
Expand All @@ -44,14 +48,20 @@ def standard_attack(
model (torch.Model): PyTorch model to attack.
tensor (torch.Tensor): Tensor to attack.
truth (torch.Tensor): Tensor representing true category.
categories (list[str]): List of categories for the model.
epsilon (float): Maximum perturbation allowed.
max_iter (int): Maximum number of iterations to perform.
Returns:
torch.Tensor: Adversarial image tensor or None if attack failed.
"""
logger.info("Conducting standard attack")

with torch.no_grad():
orig_pred = model(tensor)
logger.debug(
f"Original prediction class: {orig_pred.argmax()}, probability: {torch.nn.functional.softmax(orig_pred, dim=1).max()}"
)

orig_pred_idx: int = orig_pred.argmax().item()
truth_idx: int = truth.item()
Expand All @@ -70,6 +80,7 @@ def standard_attack(
adv_tensor = tensor.clone().detach()

for i in range(max_iter):
logger.debug(f"Current output: {model(adv_tensor)}")
model.zero_grad()
grad = compute_gradient(model=model, input=adv_tensor, target=torch.tensor([orig_pred_idx]))
adv_tensor = torch.clamp(adv_tensor + epsilon * grad.sign(), -2, 2)
Expand All @@ -89,6 +100,7 @@ def targeted_attack(
tensor: torch.Tensor,
truth: torch.Tensor,
target: torch.Tensor,
categories: list[str],
epsilon: float = 1e-3,
max_iter: int = 50,
) -> ty.Optional[tuple[torch.Tensor, int, int]]:
Expand All @@ -100,14 +112,20 @@ def targeted_attack(
tensor (torch.Tensor): Tensor to attack.
truth (torch.Tensor): Tensor representing true category.
target (torch.Tensor): Tensor representing targeted category.
categories (list[str]): List of categories for the model.
epsilon (float): Maximum perturbation allowed.
max_iter (int): Maximum number of iterations to perform.
Returns:
torch.Tensor: Adversarial image tensor or None if attack failed.
"""
logger.info("Conducting targeted attack")

with torch.no_grad():
orig_pred = model(tensor)
logger.debug(
f"Original prediction class: {orig_pred.argmax()}, probability: {torch.nn.functional.softmax(orig_pred, dim=1).max()}"
)

orig_pred_idx: int = orig_pred.argmax().item()
truth_idx: int = truth.item()
Expand Down Expand Up @@ -155,6 +173,7 @@ def execute():
model,
tensor=tensor,
truth=truth,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand All @@ -169,6 +188,7 @@ def execute():
tensor=tensor,
truth=truth,
target=target,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/e2e/input/hare_ILSVRC2012_val_00004064.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
216 changes: 173 additions & 43 deletions tests/e2e/test_end2end.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,178 @@
import pytest

from adversarial_attack.api import perform_attack
from adversarial_attack.resnet_utils import load_model_default_weights, preprocess_image, get_model_categories, load_image


@pytest.fixture
def image_truth():
return "./tests/e2e/input/lionfish_ILSVRC2012_val_00019791.JPEG", "lionfish"


@pytest.mark.parametrize("model_name", ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"])
def test_perform_attack_standard(model_name, image_truth):
image, true_category = image_truth
model = load_model_default_weights(model_name)
input_image = preprocess_image(load_image(image))
categories = get_model_categories(model_name)
result = perform_attack(
model=model,
mode="standard",
image=input_image,
categories=categories,
true_category=true_category,
epsilon=1.0e-3,
max_iter=50,
)
from adversarial_attack.resnet_utils import (
load_model_default_weights,
preprocess_image,
get_model_categories,
load_image,
)

import pytest
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_perform_attack_standard():
test_cases = [
("./tests/e2e/input/beaker_ILSVRC2012_val_00001780.JPEG", "beaker"),
("./tests/e2e/input/doormat_ILSVRC2012_val_00030383.JPEG", "doormat"),
("./tests/e2e/input/hare_ILSVRC2012_val_00004064.JPEG", "hare"),
(
"./tests/e2e/input/jack-o'-lantern_ILSVRC2012_val_00030955.JPEG",
"jack-o'-lantern",
),
("./tests/e2e/input/lawn_mower_ILSVRC2012_val_00020327.JPEG", "lawn mower"),
("./tests/e2e/input/lionfish_ILSVRC2012_val_00019791.JPEG", "lionfish"),
("./tests/e2e/input/monarch_ILSVRC2012_val_00002935.JPEG", "monarch"),
("./tests/e2e/input/pickelhaube_ILSVRC2012_val_00018444.JPEG", "pickelhaube"),
("./tests/e2e/input/sea_urchin_ILSVRC2012_val_00028454.JPEG", "sea urchin"),
]

models = ["resnet50", "resnet101", "resnet152"]

total_combinations = len(models) * len(test_cases)

total_tests = 0
success_count = 0

logger.info("Starting test for perform_attack_standard...")

for model_name in models:
logger.info(f"Testing model: {model_name}")
for image_path, true_category in test_cases:
try:
progress_percentage = (total_tests / total_combinations) * 100
logger.info(
f"Progress: {total_tests}/{total_combinations} ({progress_percentage:.2f}%) - Running test for image '{image_path}' with true category '{true_category}'"
)

model = load_model_default_weights(model_name)
input_image = preprocess_image(load_image(image_path))
categories = get_model_categories(model_name)

result = perform_attack(
model=model,
mode="standard",
image=input_image,
categories=categories,
true_category=true_category,
epsilon=1.0e-1,
max_iter=10,
)

total_tests += 1
if result is not None:
success_count += 1
else:
logger.warning(
f"Test failed for model '{model_name}', image '{image_path}', true category '{true_category}'"
)

assert result is not None


@pytest.mark.parametrize("target_category", ["goldfish", "monarch"])
@pytest.mark.parametrize("model_name", ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"])
def test_perform_attack_standard(model_name, target_category, image_truth):
image, true_category = image_truth
model = load_model_default_weights(model_name)
input_image = preprocess_image(load_image(image))
categories = get_model_categories(model_name)
result = perform_attack(
model=model,
mode="targeted",
image=input_image,
categories=categories,
true_category=true_category,
target_category=target_category,
epsilon=1.0e-3,
max_iter=50,
except Exception as e:
logger.error(
f"Error occurred for model '{model_name}', image '{image_path}', true category '{true_category}': {e}"
)
total_tests += 1 # Count this as a test to avoid skewing success rate

success_rate = success_count / total_tests if total_tests > 0 else 0
logger.info(
f"Completed all tests. Success rate: {success_rate:.2%} ({success_count}/{total_tests})"
)
assert (
success_rate >= 0.75
), f"Success rate {success_rate:.2%} is below the required threshold of 75%."


def test_perform_attack_targeted():
test_cases = [
("./tests/e2e/input/beaker_ILSVRC2012_val_00001780.JPEG", "beaker"),
("./tests/e2e/input/doormat_ILSVRC2012_val_00030383.JPEG", "doormat"),
("./tests/e2e/input/hare_ILSVRC2012_val_00004064.JPEG", "hare"),
(
"./tests/e2e/input/jack-o'-lantern_ILSVRC2012_val_00030955.JPEG",
"jack-o'-lantern",
),
("./tests/e2e/input/lawn_mower_ILSVRC2012_val_00020327.JPEG", "lawn mower"),
("./tests/e2e/input/lionfish_ILSVRC2012_val_00019791.JPEG", "lionfish"),
("./tests/e2e/input/monarch_ILSVRC2012_val_00002935.JPEG", "monarch"),
("./tests/e2e/input/pickelhaube_ILSVRC2012_val_00018444.JPEG", "pickelhaube"),
("./tests/e2e/input/sea_urchin_ILSVRC2012_val_00028454.JPEG", "sea urchin"),
]

target_categories = [
"beaker",
"bookcase",
"doormat",
"hare",
"jack-o'-lantern",
"lawn mower",
"lionfish",
"monarch",
"pickelhaube",
"sea urchin",
]

models = ["resnet50", "resnet101", "resnet152"]

total_combinations = len(models) * len(target_categories) * len(test_cases)

total_tests = 0
success_count = 0

assert result is not None
logger.info("Starting test for perform_attack_targeted...")

for model_name in models:
logger.info(f"Testing model: {model_name}")
for target_category in target_categories:
logger.info(f"Testing target category: {target_category}")
for image_path, true_category in test_cases:
try:
progress_percentage = (total_tests / total_combinations) * 100
logger.info(
f"Progress: {total_tests}/{total_combinations} ({progress_percentage:.2f}%) - Running test for image '{image_path}' with true category '{true_category}' targeting '{target_category}'"
)

model = load_model_default_weights(model_name)
input_image = preprocess_image(load_image(image_path))
categories = get_model_categories(model_name)

result = perform_attack(
model=model,
mode="targeted",
image=input_image,
categories=categories,
true_category=true_category,
target_category=target_category,
epsilon=1.0e-1,
max_iter=10,
)

total_tests += 1
if result is not None:
success_count += 1
else:
logger.warning(
f"Test failed for model '{model_name}', image '{image_path}', "
f"true category '{true_category}', targeting '{target_category}'"
)

except Exception as e:
logger.error(
f"Error occurred for model '{model_name}', image '{image_path}', "
f"true category '{true_category}', targeting '{target_category}': {e}"
)
total_tests += (
1 # Count this as a test to avoid skewing success rate
)

success_rate = success_count / total_tests if total_tests > 0 else 0
logger.info(
f"Completed all tests. Success rate: {success_rate:.2%} ({success_count}/{total_tests})"
)
assert (
success_rate >= 0.75
), f"Success rate {success_rate:.2%} is below the required threshold of 75%."
Loading

0 comments on commit e41dfbf

Please sign in to comment.