diff --git a/adversarial_attack/__main__.py b/adversarial_attack/__main__.py
index 492883b..16deefe 100644
--- a/adversarial_attack/__main__.py
+++ b/adversarial_attack/__main__.py
@@ -2,6 +2,7 @@
 import sys
 from PIL import Image
 import numpy as np
+import logging
 
 from adversarial_attack.resnet_utils import (
     AVAILABLE_MODELS,
@@ -11,9 +12,18 @@
     to_array,
     preprocess_image,
 )
-from .fgsm import get_attack_fn
 from adversarial_attack.api import perform_attack
 
+logger = logging.getLogger("adversarial_attack")
+
+
+def setup_logging(level: str):
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter("%(asctime)s:%(name)s:%(levelname)s\t%(message)s")
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.setLevel(level)
+
 
 def main():
     """
@@ -82,9 +92,16 @@
         help="Path to save the adversarial image.",
         required=False,
     )
+    parser.add_argument(
+        "--log",
+        default="WARNING",
+        help=f"Set the logging level. Available options: {list(logging._nameToLevel.keys())}",
+    )
 
     args = parser.parse_args()
 
+    setup_logging(level=args.log)
+
     if args.mode == "targeted" and args.category_target is None:
         raise ValueError("Target category is required for targeted attacks.")
@@ -99,8 +116,8 @@
         image=image_tensor,
         categories=get_model_categories(args.model),
         true_category=args.category_truth,
-        epsilon=args.epsilon,
-        max_iter=args.max_iterations,
+        epsilon=float(args.epsilon),
+        max_iter=int(args.max_iterations),
         target_category=args.category_target,
     )
diff --git a/adversarial_attack/fgsm.py b/adversarial_attack/fgsm.py
index 0da9e2c..7eadee3 100644
--- a/adversarial_attack/fgsm.py
+++ b/adversarial_attack/fgsm.py
@@ -1,9 +1,12 @@
 import typing as ty
 import torch
 import warnings
+import logging
 
 from .resnet_utils import category_to_tensor
 
+logger = logging.getLogger("adversarial_attack")
+
 
 def compute_gradient(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
     """
@@ -52,8 +55,13 @@
     Returns:
         torch.Tensor: Adversarial image tensor or None if attack failed.
     """
+    logger.info("Conducting standard attack")
+
     with torch.no_grad():
         orig_pred = model(tensor)
+        logger.debug(
+            f"Original prediction class: {orig_pred.argmax()}, probability: {torch.nn.functional.softmax(orig_pred, dim=1).max()}"
+        )
 
     if orig_pred.argmax().item() != truth.item():
         warnings.warn(
@@ -70,12 +78,18 @@
     orig_pred_idx = torch.tensor([orig_pred.argmax().item()])
 
     for i in range(max_iter):
+        logger.debug(f"Current output: {model(adv_tensor)}")
         model.zero_grad()
         grad = compute_gradient(model=model, input=adv_tensor, target=orig_pred_idx)
         adv_tensor = torch.clamp(adv_tensor + epsilon * grad.sign(), -2, 2)
-        new_pred = model(adv_tensor).argmax()
+        new_output = model(adv_tensor)
+        new_pred = new_output.argmax()
+        logger.debug(
+            f"attack iteration {i}, current prediction: {new_pred.item()}, current max probability: {torch.nn.functional.softmax(new_output, dim=1).max()}"
+        )
         if orig_pred_idx.item() != new_pred:
-            return adv_tensor, orig_pred, new_pred
+            logger.info(f"Standard attack successful.")
+            return adv_tensor, orig_pred.argmax(), new_pred
 
     warnings.warn(
         f"Failed to alter the prediction of the model after {max_iter} tries.",
@@ -108,8 +122,13 @@
     Returns:
         torch.Tensor: Adversarial image tensor or None if attack failed.
     """
+    logger.info("Conducting targeted attack")
+
     with torch.no_grad():
         orig_pred = model(tensor)
+        logger.debug(
+            f"Original prediction class: {orig_pred.argmax()}, probability: {torch.nn.functional.softmax(orig_pred, dim=1).max()}"
+        )
 
     if orig_pred.argmax().item() != truth.item():
         raise ValueError(
@@ -119,15 +138,19 @@
 
     # make a copy of the input tensor
     adv_tensor = tensor.clone().detach()
-    orig_pred_idx = torch.tensor([orig_pred.argmax().item()])
 
     for i in range(max_iter):
         model.zero_grad()
         grad = compute_gradient(model=model, input=adv_tensor, target=target)
         adv_tensor = torch.clamp(adv_tensor - epsilon * grad.sign(), -2, 2)
-        new_pred = model(adv_tensor).argmax()
-        if orig_pred_idx.item() == target.item():
-            return adv_tensor, orig_pred, new_pred
+        new_output = model(adv_tensor)
+        new_pred = new_output.argmax()
+        logger.debug(
+            f"Attack iteration {i}, target: {target.item()}, current prediction: {new_pred.item()}, current max probability: {torch.nn.functional.softmax(new_output, dim=1).max()}"
+        )
+        if new_pred.item() == target.item():
+            logger.info(f"Targeted attack successful.")
+            return adv_tensor, orig_pred.argmax(), new_pred
 
     warnings.warn(
         f"Failed to alter the prediction of the model after {max_iter} tries.",
diff --git a/tests/unit/test_fgsm.py b/tests/unit/test_fgsm.py
index 1930af2..f1f6e7b 100644
--- a/tests/unit/test_fgsm.py
+++ b/tests/unit/test_fgsm.py
@@ -57,7 +57,7 @@ def test_standard_attack_success(model):
     assert result is not None, "Attack should succeed."
     adv_tensor, orig_pred, new_pred = result
 
-    assert orig_pred.argmax() != new_pred, "Attack should change the model prediction."
+    assert new_pred.item() != orig_pred.item(), "Attack should change the model prediction."
 
 
 def test_targeted_attack_success(model):
@@ -80,9 +80,7 @@
     assert result is not None, "Attack should succeed."
     adv_tensor, orig_pred, new_pred = result
 
-    assert (
-        orig_pred.argmax() != 2
-    ), "Attack should change the model prediction to target."
+    assert new_pred.item() == 2, "Attack should change the model prediction to target."
 
 
 def test_standard_attack_failure(model):