diff --git a/adversarial_attack/fgsm.py b/adversarial_attack/fgsm.py index c38e401..040f03e 100644 --- a/adversarial_attack/fgsm.py +++ b/adversarial_attack/fgsm.py @@ -37,7 +37,6 @@ def standard_attack( model: torch.nn.Module, tensor: torch.Tensor, truth: torch.Tensor, - categories: list[str], epsilon: float = 1e-3, max_iter: int = 50, ) -> ty.Optional[tuple[torch.Tensor, int, int]]: @@ -48,7 +47,6 @@ def standard_attack( model (torch.Model): PyTorch model to attack. tensor (torch.Tensor): Tensor to attack. truth (torch.Tensor): Tensor representing true category. - categories (list[str]): List of categories for the model. epsilon (float): Maximum perturbation allowed. max_iter (int): Maximum number of iterations to perform. @@ -80,7 +78,9 @@ def standard_attack( for i in range(max_iter): model.zero_grad() - grad = compute_gradient(model=model, input=adv_tensor, target=torch.tensor([orig_pred_idx])) + grad = compute_gradient( + model=model, input=adv_tensor, target=torch.tensor([orig_pred_idx]) + ) adv_tensor = torch.clamp(adv_tensor + epsilon * grad.sign(), -2, 2) new_pred_idx = model(adv_tensor).argmax() if orig_pred_idx != new_pred_idx: @@ -97,7 +97,6 @@ def targeted_attack( tensor: torch.Tensor, truth: torch.Tensor, target: torch.Tensor, - categories: list[str], epsilon: float = 1e-3, max_iter: int = 50, ) -> ty.Optional[tuple[torch.Tensor, int, int]]: @@ -109,7 +108,6 @@ def targeted_attack( tensor (torch.Tensor): Tensor to attack. truth (torch.Tensor): Tensor representing true category. target (torch.Tensor): Tensor representing targeted category. - categories (list[str]): List of categories for the model. epsilon (float): Maximum perturbation allowed. max_iter (int): Maximum number of iterations to perform. @@ -128,9 +126,11 @@ def targeted_attack( truth_idx: int = truth.item() if orig_pred_idx != truth_idx: - raise ValueError( - f"Model prediction {orig_pred_idx} does not match true class {truth_idx}.", - f"It is therefore pointless to perform an attack.", + logger.warning( + ( + f"Model prediction {orig_pred_idx} does not match true class {truth_idx}." + f"It is therefore pointless to perform an attack.", + ) ) return None @@ -170,7 +170,6 @@ def execute(): model, tensor=tensor, truth=truth, - categories=categories, epsilon=epsilon, max_iter=max_iter, ) @@ -185,7 +184,6 @@ def execute(): tensor=tensor, truth=truth, target=target, - categories=categories, epsilon=epsilon, max_iter=max_iter, ) diff --git a/tests/unit/test_fgsm.py b/tests/unit/test_fgsm.py index 1bd8f10..337971e 100644 --- a/tests/unit/test_fgsm.py +++ b/tests/unit/test_fgsm.py @@ -42,7 +42,6 @@ def test_compute_gradient(model, inputs_and_targets): def test_standard_attack_success(model): tensor = torch.tensor([[1.0, 0.0, 0.5]]) truth = torch.tensor([0]) - categories = ["cat1", "cat2", "cat3"] epsilon = 0.1 max_iter = 50 @@ -50,7 +49,6 @@ def test_standard_attack_success(model): model=model, tensor=tensor, truth=truth, - categories=categories, epsilon=epsilon, max_iter=max_iter, ) @@ -64,7 +62,6 @@ def test_targeted_attack_success(model): tensor = torch.tensor([[1.0, 0.0, 0.5]]) truth = torch.tensor([0]) target = torch.tensor([2]) - categories = ["cat1", "cat2", "cat3"] epsilon = 0.1 max_iter = 50 @@ -73,7 +70,6 @@ def test_targeted_attack_success(model): tensor=tensor, truth=truth, target=target, - categories=categories, epsilon=epsilon, max_iter=max_iter, ) @@ -86,22 +82,20 @@ def test_targeted_attack_success(model): def test_standard_attack_failure(model): tensor = torch.tensor([[1.0, 0.0, 0.5]]) truth = torch.tensor([1]) # Intentionally mismatched target - categories = ["cat1", "cat2", "cat3"] epsilon = 0.1 max_iter = 50 - assert standard_attack(model=model, tensor=tensor, truth=truth, categories=categories, epsilon=epsilon, max_iter=max_iter) is None + assert standard_attack(model=model, tensor=tensor, truth=truth, epsilon=epsilon, max_iter=max_iter) is None def test_targeted_attack_failure(model): tensor = torch.tensor([[1.0, 0.0, 0.5]]) truth = torch.tensor([1]) # Intentionally mismatched target target = torch.tensor([2]) - categories = ["cat1", "cat2", "cat3"] epsilon = 0.1 max_iter = 50 - assert targeted_attack(model=model, tensor=tensor, truth=truth, target=target, categories=categories, epsilon=epsilon, max_iter=max_iter) is None + assert targeted_attack(model=model, tensor=tensor, truth=truth, target=target, epsilon=epsilon, max_iter=max_iter) is None def test_standard_attack_no_change(model): tensor = torch.tensor([[1.0, 0.0, 0.5]]) @@ -114,7 +108,6 @@ def test_standard_attack_no_change(model): model=model, tensor=tensor, truth=truth, - categories=categories, epsilon=epsilon, max_iter=max_iter, ) @@ -138,7 +131,6 @@ def test_targeted_attack_no_change(model): tensor=tensor, truth=truth, target=target, - categories=categories, epsilon=epsilon, max_iter=max_iter, )