-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
13a77a4
commit 576339b
Showing
14 changed files
with
112 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
Gradient-norm attack. Proposed for MIA in multiple settings, and particularly experimented for pre-training data and LLMs in https://arxiv.org/abs/2402.17012 | ||
""" | ||
|
||
import torch as ch | ||
import numpy as np | ||
from mimir.attacks.all_attacks import Attack | ||
from mimir.models import Model | ||
from mimir.config import ExperimentConfig | ||
|
||
|
||
class GradNormAttack(Attack): | ||
def __init__(self, config: ExperimentConfig, model: Model): | ||
super().__init__(config, model, ref_model=None, is_blackbox=False) | ||
|
||
def _attack(self, document, probs, tokens=None, **kwargs): | ||
""" | ||
Gradient Norm Attack. Computes p-norm of gradients w.r.t. input tokens. | ||
""" | ||
# We ignore probs here since they are computed in the general case without gradient-tracking (to save memory) | ||
|
||
# Hyper-params specific to min-k attack | ||
p: float = kwargs.get("p", np.inf) | ||
if p not in [1, 2, np.inf]: | ||
raise ValueError(f"Invalid p-norm value: {p}.") | ||
|
||
# Make sure model params require gradients | ||
# for name, param in self.target_model.model.named_parameters(): | ||
# param.requires_grad = True | ||
|
||
# Get gradients for model parameters | ||
self.target_model.model.zero_grad() | ||
all_prob = self.target_model.get_probabilities(document, tokens=tokens, no_grads=False) | ||
loss = - ch.mean(all_prob) | ||
loss.backward() | ||
|
||
# Compute p-norm of gradients (for all model params where grad exists) | ||
grad_norms = [] | ||
for param in self.target_model.model.parameters(): | ||
if param.grad is not None: | ||
grad_norms.append(param.grad.detach().norm(p)) | ||
grad_norm = ch.stack(grad_norms).mean() | ||
|
||
# Zero out gradients again | ||
self.target_model.model.zero_grad() | ||
|
||
return -grad_norm.cpu().numpy() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.