Skip to content

Commit

Permalink
remove categories, fix logging
Browse files Browse the repository at this point in the history
  • Loading branch information
tomcarter23 committed Dec 5, 2024
1 parent 3bad185 commit bbb3862
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
18 changes: 8 additions & 10 deletions adversarial_attack/fgsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def standard_attack(
model: torch.nn.Module,
tensor: torch.Tensor,
truth: torch.Tensor,
categories: list[str],
epsilon: float = 1e-3,
max_iter: int = 50,
) -> ty.Optional[tuple[torch.Tensor, int, int]]:
Expand All @@ -48,7 +47,6 @@ def standard_attack(
model (torch.Model): PyTorch model to attack.
tensor (torch.Tensor): Tensor to attack.
truth (torch.Tensor): Tensor representing true category.
categories (list[str]): List of categories for the model.
epsilon (float): Maximum perturbation allowed.
max_iter (int): Maximum number of iterations to perform.
Expand Down Expand Up @@ -80,7 +78,9 @@ def standard_attack(

for i in range(max_iter):
model.zero_grad()
grad = compute_gradient(model=model, input=adv_tensor, target=torch.tensor([orig_pred_idx]))
grad = compute_gradient(
model=model, input=adv_tensor, target=torch.tensor([orig_pred_idx])
)
adv_tensor = torch.clamp(adv_tensor + epsilon * grad.sign(), -2, 2)
new_pred_idx = model(adv_tensor).argmax()
if orig_pred_idx != new_pred_idx:
Expand All @@ -97,7 +97,6 @@ def targeted_attack(
tensor: torch.Tensor,
truth: torch.Tensor,
target: torch.Tensor,
categories: list[str],
epsilon: float = 1e-3,
max_iter: int = 50,
) -> ty.Optional[tuple[torch.Tensor, int, int]]:
Expand All @@ -109,7 +108,6 @@ def targeted_attack(
tensor (torch.Tensor): Tensor to attack.
truth (torch.Tensor): Tensor representing true category.
target (torch.Tensor): Tensor representing targeted category.
categories (list[str]): List of categories for the model.
epsilon (float): Maximum perturbation allowed.
max_iter (int): Maximum number of iterations to perform.
Expand All @@ -128,9 +126,11 @@ def targeted_attack(
truth_idx: int = truth.item()

if orig_pred_idx != truth_idx:
raise ValueError(
f"Model prediction {orig_pred_idx} does not match true class {truth_idx}.",
f"It is therefore pointless to perform an attack.",
logger.warning(
(
f"Model prediction {orig_pred_idx} does not match true class {truth_idx}."
f"It is therefore pointless to perform an attack.",
)
)
return None

Expand Down Expand Up @@ -170,7 +170,6 @@ def execute():
model,
tensor=tensor,
truth=truth,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand All @@ -185,7 +184,6 @@ def execute():
tensor=tensor,
truth=truth,
target=target,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand Down
12 changes: 2 additions & 10 deletions tests/unit/test_fgsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ def test_compute_gradient(model, inputs_and_targets):
def test_standard_attack_success(model):
tensor = torch.tensor([[1.0, 0.0, 0.5]])
truth = torch.tensor([0])
categories = ["cat1", "cat2", "cat3"]
epsilon = 0.1
max_iter = 50

result = standard_attack(
model=model,
tensor=tensor,
truth=truth,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand All @@ -64,7 +62,6 @@ def test_targeted_attack_success(model):
tensor = torch.tensor([[1.0, 0.0, 0.5]])
truth = torch.tensor([0])
target = torch.tensor([2])
categories = ["cat1", "cat2", "cat3"]
epsilon = 0.1
max_iter = 50

Expand All @@ -73,7 +70,6 @@ def test_targeted_attack_success(model):
tensor=tensor,
truth=truth,
target=target,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand All @@ -86,22 +82,20 @@ def test_targeted_attack_success(model):
def test_standard_attack_failure(model):
tensor = torch.tensor([[1.0, 0.0, 0.5]])
truth = torch.tensor([1]) # Intentionally mismatched target
categories = ["cat1", "cat2", "cat3"]
epsilon = 0.1
max_iter = 50

assert standard_attack(model=model, tensor=tensor, truth=truth, categories=categories, epsilon=epsilon, max_iter=max_iter) is None
assert standard_attack(model=model, tensor=tensor, truth=truth, epsilon=epsilon, max_iter=max_iter) is None


def test_targeted_attack_failure(model):
tensor = torch.tensor([[1.0, 0.0, 0.5]])
truth = torch.tensor([1]) # Intentionally mismatched target
target = torch.tensor([2])
categories = ["cat1", "cat2", "cat3"]
epsilon = 0.1
max_iter = 50

assert targeted_attack(model=model, tensor=tensor, truth=truth, target=target, categories=categories, epsilon=epsilon, max_iter=max_iter) is None
assert targeted_attack(model=model, tensor=tensor, truth=truth, target=target, epsilon=epsilon, max_iter=max_iter) is None

def test_standard_attack_no_change(model):
tensor = torch.tensor([[1.0, 0.0, 0.5]])
Expand All @@ -114,7 +108,6 @@ def test_standard_attack_no_change(model):
model=model,
tensor=tensor,
truth=truth,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand All @@ -138,7 +131,6 @@ def test_targeted_attack_no_change(model):
tensor=tensor,
truth=truth,
target=target,
categories=categories,
epsilon=epsilon,
max_iter=max_iter,
)
Expand Down

0 comments on commit bbb3862

Please sign in to comment.