Skip to content

Commit

Permalink
Parameterise e2e (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomcarter23 authored Dec 4, 2024
1 parent 089eb85 commit b8e9072
Show file tree
Hide file tree
Showing 15 changed files with 278 additions and 66 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/e2e_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .[test]
- name: Run Tests
- name: Run Standard FGSM Tests
run: |
pytest ./tests/e2e/
pytest ./tests/e2e/test_end2end.py::test_perform_attack_standard -v --log-cli-level=INFO
- name: Run Targeted FGSM Tests
run: |
pytest ./tests/e2e/test_end2end.py::test_perform_attack_targeted -v --log-cli-level=INFO
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ jobs:
pip install -e .[test]
- name: Run Tests
run: |
pytest ./tests/unit/
pytest ./tests/unit/ -v
14 changes: 10 additions & 4 deletions adversarial_attack/fgsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def standard_attack(
model: torch.nn.Module,
tensor: torch.Tensor,
truth: torch.Tensor,
categories: list[str],
epsilon: float = 1e-3,
max_iter: int = 50,
) -> ty.Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
Expand All @@ -47,6 +48,7 @@ def standard_attack(
model (torch.Model): PyTorch model to attack.
tensor (torch.Tensor): Tensor to attack.
truth (torch.Tensor): Tensor representing true category.
categories (list[str]): List of categories for the model.
epsilon (float): Maximum perturbation allowed.
max_iter (int): Maximum number of iterations to perform.
Expand All @@ -64,8 +66,8 @@ def standard_attack(
if orig_pred.argmax().item() != truth.item():
warnings.warn(
(
f"Model prediction {orig_pred.argmax().item()} does not match true class {truth.item()}."
f"It is therefore pointless to perform an attack."
f"Model prediction `{categories[orig_pred.argmax().item()]}` does not match true class `{categories[truth.item()]}`."
f"It is therefore pointless to perform an attack. Not attacking."
),
RuntimeWarning,
)
Expand Down Expand Up @@ -101,6 +103,7 @@ def targeted_attack(
tensor: torch.Tensor,
truth: torch.Tensor,
target: torch.Tensor,
categories: list[str],
epsilon: float = 1e-3,
max_iter: int = 50,
) -> ty.Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
Expand All @@ -112,6 +115,7 @@ def targeted_attack(
tensor (torch.Tensor): Tensor to attack.
truth (torch.Tensor): Tensor representing true category.
target (torch.Tensor): Tensor representing targeted category.
categories (list[str]): List of categories for the model.
epsilon (float): Maximum perturbation allowed.
max_iter (int): Maximum number of iterations to perform.
Expand All @@ -128,8 +132,8 @@ def targeted_attack(

if orig_pred.argmax().item() != truth.item():
raise ValueError(
f"Model prediction {orig_pred.argmax().item()} does not match true class {truth.item()}.",
f"It is therefore pointless to perform an attack.",
f"Model prediction {categories[orig_pred.argmax().item()]} does not match true class {categories[truth.item()]}.",
f"It is therefore pointless to perform an attack. Not attacking.",
)

# make a copy of the input tensor
Expand Down Expand Up @@ -174,6 +178,7 @@ def execute():
model,
tensor=tensor,
truth=truth,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand All @@ -188,6 +193,7 @@ def execute():
tensor=tensor,
truth=truth,
target=target,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/e2e/input/hare_ILSVRC2012_val_00004064.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
216 changes: 173 additions & 43 deletions tests/e2e/test_end2end.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,178 @@
import pytest

from adversarial_attack.api import perform_attack
from adversarial_attack.resnet_utils import load_model_default_weights, preprocess_image, get_model_categories, load_image


@pytest.fixture
def image_truth():
return "./tests/e2e/input/lionfish_ILSVRC2012_val_00019791.JPEG", "lionfish"


@pytest.mark.parametrize("model_name", ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"])
def test_perform_attack_standard(model_name, image_truth):
image, true_category = image_truth
model = load_model_default_weights(model_name)
input_image = preprocess_image(load_image(image))
categories = get_model_categories(model_name)
result = perform_attack(
model=model,
mode="standard",
image=input_image,
categories=categories,
true_category=true_category,
epsilon=1.0e-3,
max_iter=50,
)
from adversarial_attack.resnet_utils import (
load_model_default_weights,
preprocess_image,
get_model_categories,
load_image,
)

import pytest
import logging

# Configure logging.
# basicConfig runs when this module is imported and requests INFO-level
# output on the root logger; per the standard library it does nothing if the
# root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the tests below for progress and result messages.
logger = logging.getLogger(__name__)


def test_perform_attack_standard():
    """Check the standard attack end to end across several models and images.

    Calls ``perform_attack`` in ``"standard"`` mode for every combination of
    the listed models and sample images. A combination counts as a success
    when ``perform_attack`` returns something other than ``None``; a ``None``
    result or a raised exception counts as a failure. The test passes when
    at least 75% of all combinations succeed.
    """
    test_cases = [
        ("./tests/e2e/input/beaker_ILSVRC2012_val_00001780.JPEG", "beaker"),
        ("./tests/e2e/input/doormat_ILSVRC2012_val_00030383.JPEG", "doormat"),
        ("./tests/e2e/input/hare_ILSVRC2012_val_00004064.JPEG", "hare"),
        (
            "./tests/e2e/input/jack-o'-lantern_ILSVRC2012_val_00030955.JPEG",
            "jack-o'-lantern",
        ),
        ("./tests/e2e/input/lawn_mower_ILSVRC2012_val_00020327.JPEG", "lawn mower"),
        ("./tests/e2e/input/lionfish_ILSVRC2012_val_00019791.JPEG", "lionfish"),
        ("./tests/e2e/input/monarch_ILSVRC2012_val_00002935.JPEG", "monarch"),
        ("./tests/e2e/input/pickelhaube_ILSVRC2012_val_00018444.JPEG", "pickelhaube"),
        ("./tests/e2e/input/sea_urchin_ILSVRC2012_val_00028454.JPEG", "sea urchin"),
    ]

    models = ["resnet50", "resnet101", "resnet152"]

    total_combinations = len(models) * len(test_cases)

    total_tests = 0
    success_count = 0

    logger.info("Starting test for perform_attack_standard...")

    for model_name in models:
        logger.info(f"Testing model: {model_name}")
        for image_path, true_category in test_cases:
            try:
                progress_percentage = (total_tests / total_combinations) * 100
                logger.info(
                    f"Progress: {total_tests}/{total_combinations} ({progress_percentage:.2f}%) - Running test for image '{image_path}' with true category '{true_category}'"
                )

                # NOTE(review): the model and its categories depend only on
                # model_name but are reloaded for every image. They could
                # probably be loaded once per model -- confirm perform_attack
                # does not mutate the model before hoisting.
                model = load_model_default_weights(model_name)
                input_image = preprocess_image(load_image(image_path))
                categories = get_model_categories(model_name)

                result = perform_attack(
                    model=model,
                    mode="standard",
                    image=input_image,
                    categories=categories,
                    true_category=true_category,
                    epsilon=1.0e-1,
                    max_iter=10,
                )
            except Exception as e:
                # A single failing combination must not abort the whole run;
                # logger.exception keeps the traceback so the failure can
                # still be diagnosed from the log.
                logger.exception(
                    f"Error occurred for model '{model_name}', image '{image_path}', true category '{true_category}': {e}"
                )
                total_tests += 1  # Count this as a test to avoid skewing success rate
            else:
                total_tests += 1
                if result is not None:
                    success_count += 1
                else:
                    logger.warning(
                        f"Test failed for model '{model_name}', image '{image_path}', true category '{true_category}'"
                    )

    success_rate = success_count / total_tests if total_tests > 0 else 0
    logger.info(
        f"Completed all tests. Success rate: {success_rate:.2%} ({success_count}/{total_tests})"
    )
    assert (
        success_rate >= 0.75
    ), f"Success rate {success_rate:.2%} is below the required threshold of 75%."


def test_perform_attack_targeted():
    """Check the targeted attack end to end across models, targets and images.

    Calls ``perform_attack`` in ``"targeted"`` mode for every combination of
    the listed models, target categories and sample images. A combination
    counts as a success when ``perform_attack`` returns something other than
    ``None``; a ``None`` result or a raised exception counts as a failure.
    The test passes when at least 75% of all combinations succeed.
    """
    test_cases = [
        ("./tests/e2e/input/beaker_ILSVRC2012_val_00001780.JPEG", "beaker"),
        ("./tests/e2e/input/doormat_ILSVRC2012_val_00030383.JPEG", "doormat"),
        ("./tests/e2e/input/hare_ILSVRC2012_val_00004064.JPEG", "hare"),
        (
            "./tests/e2e/input/jack-o'-lantern_ILSVRC2012_val_00030955.JPEG",
            "jack-o'-lantern",
        ),
        ("./tests/e2e/input/lawn_mower_ILSVRC2012_val_00020327.JPEG", "lawn mower"),
        ("./tests/e2e/input/lionfish_ILSVRC2012_val_00019791.JPEG", "lionfish"),
        ("./tests/e2e/input/monarch_ILSVRC2012_val_00002935.JPEG", "monarch"),
        ("./tests/e2e/input/pickelhaube_ILSVRC2012_val_00018444.JPEG", "pickelhaube"),
        ("./tests/e2e/input/sea_urchin_ILSVRC2012_val_00028454.JPEG", "sea urchin"),
    ]

    target_categories = [
        "beaker",
        "bookcase",
        "doormat",
        "hare",
        "jack-o'-lantern",
        "lawn mower",
        "lionfish",
        "monarch",
        "pickelhaube",
        "sea urchin",
    ]

    models = ["resnet50", "resnet101", "resnet152"]

    total_combinations = len(models) * len(target_categories) * len(test_cases)

    total_tests = 0
    success_count = 0

    logger.info("Starting test for perform_attack_targeted...")

    for model_name in models:
        logger.info(f"Testing model: {model_name}")
        for target_category in target_categories:
            logger.info(f"Testing target category: {target_category}")
            for image_path, true_category in test_cases:
                try:
                    progress_percentage = (total_tests / total_combinations) * 100
                    logger.info(
                        f"Progress: {total_tests}/{total_combinations} ({progress_percentage:.2f}%) - Running test for image '{image_path}' with true category '{true_category}' targeting '{target_category}'"
                    )

                    # NOTE(review): the model and its categories depend only
                    # on model_name but are reloaded for every target/image
                    # pair. They could probably be loaded once per model --
                    # confirm perform_attack does not mutate the model
                    # before hoisting.
                    model = load_model_default_weights(model_name)
                    input_image = preprocess_image(load_image(image_path))
                    categories = get_model_categories(model_name)

                    result = perform_attack(
                        model=model,
                        mode="targeted",
                        image=input_image,
                        categories=categories,
                        true_category=true_category,
                        target_category=target_category,
                        epsilon=1.0e-1,
                        max_iter=10,
                    )
                except Exception as e:
                    # A single failing combination must not abort the whole
                    # run; logger.exception keeps the traceback so the
                    # failure can still be diagnosed from the log.
                    logger.exception(
                        f"Error occurred for model '{model_name}', image '{image_path}', "
                        f"true category '{true_category}', targeting '{target_category}': {e}"
                    )
                    total_tests += 1  # Count this as a test to avoid skewing success rate
                else:
                    total_tests += 1
                    if result is not None:
                        success_count += 1
                    else:
                        logger.warning(
                            f"Test failed for model '{model_name}', image '{image_path}', "
                            f"true category '{true_category}', targeting '{target_category}'"
                        )

    success_rate = success_count / total_tests if total_tests > 0 else 0
    logger.info(
        f"Completed all tests. Success rate: {success_rate:.2%} ({success_count}/{total_tests})"
    )
    assert (
        success_rate >= 0.75
    ), f"Success rate {success_rate:.2%} is below the required threshold of 75%."
Loading

0 comments on commit b8e9072

Please sign in to comment.