Add gradnorm attack
iamgroot42 committed Mar 27, 2024
1 parent 13a77a4 commit 576339b
Showing 14 changed files with 112 additions and 51 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -53,6 +53,11 @@ We include and implement the following attacks, as described in our paper.
- [Zlib Entropy](https://www.usenix.org/system/files/sec21-carlini-extracting.pdf) (`zlib`). Uses the zlib compression size of a sample to approximate the local difficulty of the sample.
- [Min-k% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses the k% of tokens with minimum likelihood for score computation.
- [Neighborhood](https://aclanthology.org/2023.findings-acl.719/) (`ne`). Generates neighbors using an auxiliary model and measures the change in likelihood.
- [Gradient Norm](https://arxiv.org/abs/2402.17012) (`gradnorm`). Uses the gradient norm of the target datapoint as the score (see the sketch below).
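
A rough sketch of the gradient-norm idea in plain PyTorch (not code from this repository; the model name `gpt2`, the choice `p=2`, and the helper name `gradnorm_score` are illustrative placeholders):

```python
# Illustrative gradient-norm membership score: backprop the LM loss on a candidate
# sample and aggregate per-parameter gradient norms (negated, as in mimir's gradnorm.py).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def gradnorm_score(text: str, p: float = 2.0) -> float:
    enc = tok(text, return_tensors="pt")
    model.zero_grad()
    loss = model(**enc, labels=enc["input_ids"]).loss  # causal LM loss on the sample
    loss.backward()
    norms = [prm.grad.detach().norm(p) for prm in model.parameters() if prm.grad is not None]
    model.zero_grad()
    # Smaller gradients suggest the sample is "easier" for the model;
    # negate so that more member-like samples receive higher scores.
    return -torch.stack(norms).mean().item()
```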

## Adding your own dataset

To extend the package to your own dataset, you can load your data directly inside `load_cached()` in `data_utils.py`, or add an additional if-else branch within `load()` in `data_utils.py` if the data cannot easily be loaded from memory (or some other source). We will probably add a more general mechanism for this in the future.
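
A purely hypothetical sketch of such a branch (the actual signature and return format of `load()` in `data_utils.py` may differ):

```python
# Hypothetical sketch only -- match the real signature/return format of load() in data_utils.py.
import json

def load(name: str, **kwargs):
    if name == "my_custom_dataset":
        # Return the raw text records of your dataset
        with open("/path/to/my_custom_dataset.jsonl") as f:
            return [json.loads(line)["text"] for line in f]
    # ... existing branches for supported datasets ...
    raise ValueError(f"Unknown dataset: {name}")
```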

## Adding your own attack

mimir/attacks/blackbox_attacks.py → mimir/attacks/all_attacks.py
@@ -7,22 +7,24 @@


# Attack definitions
-class BlackBoxAttacks(str, Enum):
+class AllAttacks(str, Enum):
LOSS = "loss" # Done
REFERENCE_BASED = "ref" # Done
ZLIB = "zlib" # Done
MIN_K = "min_k" # Done
NEIGHBOR = "ne" # Done
GRADNORM = "gradnorm" # Done
# QUANTILE = "quantile" # Uncomment when tested implementation is available


# Base attack class
class Attack:
-def __init__(self, config, target_model: Model, ref_model: Model = None):
+def __init__(self, config, target_model: Model, ref_model: Model = None, is_blackbox: bool = True):
self.config = config
self.target_model = target_model
self.ref_model = ref_model
self.is_loaded = False
self.is_blackbox = is_blackbox

def load(self):
"""
47 changes: 47 additions & 0 deletions mimir/attacks/gradnorm.py
@@ -0,0 +1,47 @@
"""
Gradient-norm attack. Proposed for MIA in multiple settings, and in particular evaluated for pre-training data and LLMs in https://arxiv.org/abs/2402.17012
"""

import torch as ch
import numpy as np
from mimir.attacks.all_attacks import Attack
from mimir.models import Model
from mimir.config import ExperimentConfig


class GradNormAttack(Attack):
    def __init__(self, config: ExperimentConfig, model: Model):
        super().__init__(config, model, ref_model=None, is_blackbox=False)

    def _attack(self, document, probs, tokens=None, **kwargs):
        """
        Gradient-norm attack. Computes the p-norm of the loss gradients w.r.t. model parameters.
        """
        # We ignore probs here, since they are computed in the general case without gradient-tracking (to save memory)

        # Hyper-params specific to the gradient-norm attack
        p: float = kwargs.get("p", np.inf)
        if p not in [1, 2, np.inf]:
            raise ValueError(f"Invalid p-norm value: {p}.")

        # Make sure model params require gradients
        # for name, param in self.target_model.model.named_parameters():
        #     param.requires_grad = True

        # Get gradients for model parameters
        self.target_model.model.zero_grad()
        all_prob = self.target_model.get_probabilities(document, tokens=tokens, no_grads=False)
        loss = -ch.mean(all_prob)
        loss.backward()

        # Compute p-norm of gradients (for all model params where grad exists)
        grad_norms = []
        for param in self.target_model.model.parameters():
            if param.grad is not None:
                grad_norms.append(param.grad.detach().norm(p))
        grad_norm = ch.stack(grad_norms).mean()

        # Zero out gradients again
        self.target_model.model.zero_grad()

        return -grad_norm.cpu().numpy()
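
Once registered in `mimir/attacks/utils.py` (later in this diff), the new attack is constructed and invoked like the existing ones. A rough usage sketch, assuming `Attack.attack()` forwards keyword arguments such as `p` to `_attack()`; `config`, `target_model`, `document`, and `tokens` are placeholders:

# Rough usage sketch -- not part of the commit; assumes attack() forwards kwargs to _attack().
from mimir.attacks.all_attacks import AllAttacks
from mimir.attacks.utils import get_attacker

attacker = get_attacker(AllAttacks.GRADNORM)(config, target_model)  # config: ExperimentConfig, target_model: Model
score = attacker.attack(document, probs=None, tokens=tokens, p=2)   # p may be 1, 2, or np.inf
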
2 changes: 1 addition & 1 deletion mimir/attacks/loss.py
@@ -2,7 +2,7 @@
Straight-forward LOSS attack, as described in https://ieeexplore.ieee.org/abstract/document/8429311
"""
import torch as ch
-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack
from mimir.models import Model
from mimir.config import ExperimentConfig

2 changes: 1 addition & 1 deletion mimir/attacks/min_k.py
@@ -3,7 +3,7 @@
"""
import torch as ch
import numpy as np
-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack
from mimir.models import Model
from mimir.config import ExperimentConfig

2 changes: 1 addition & 1 deletion mimir/attacks/neighborhood.py
@@ -14,7 +14,7 @@
from mimir.config import ExperimentConfig
from mimir.attacks.attack_utils import count_masks, apply_extracted_fills
from mimir.models import Model, ReferenceModel
-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack


class NeighborhoodAttack(Attack):
2 changes: 1 addition & 1 deletion mimir/attacks/quantile.py
@@ -9,7 +9,7 @@
from transformers import TrainingArguments, Trainer
from datasets import Dataset

-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack


class CustomTrainer(Trainer):
2 changes: 1 addition & 1 deletion mimir/attacks/reference.py
@@ -1,7 +1,7 @@
"""
Reference-based attacks.
"""
-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack
from mimir.models import Model, ReferenceModel
from mimir.config import ExperimentConfig

14 changes: 8 additions & 6 deletions mimir/attacks/utils.py
@@ -1,20 +1,22 @@
-from mimir.attacks.blackbox_attacks import BlackBoxAttacks
+from mimir.attacks.all_attacks import AllAttacks

from mimir.attacks.loss import LOSSAttack
from mimir.attacks.reference import ReferenceAttack
from mimir.attacks.zlib import ZLIBAttack
from mimir.attacks.min_k import MinKProbAttack
from mimir.attacks.neighborhood import NeighborhoodAttack
from mimir.attacks.gradnorm import GradNormAttack


# TODO Use decorators to link attack implementations with enum above
def get_attacker(attack: str):
mapping = {
-BlackBoxAttacks.LOSS: LOSSAttack,
-BlackBoxAttacks.REFERENCE_BASED: ReferenceAttack,
-BlackBoxAttacks.ZLIB: ZLIBAttack,
-BlackBoxAttacks.MIN_K: MinKProbAttack,
-BlackBoxAttacks.NEIGHBOR: NeighborhoodAttack,
+AllAttacks.LOSS: LOSSAttack,
+AllAttacks.REFERENCE_BASED: ReferenceAttack,
+AllAttacks.ZLIB: ZLIBAttack,
+AllAttacks.MIN_K: MinKProbAttack,
+AllAttacks.NEIGHBOR: NeighborhoodAttack,
+AllAttacks.GRADNORM: GradNormAttack,
}
attack_cls = mapping.get(attack, None)
if attack_cls is None:
2 changes: 1 addition & 1 deletion mimir/attacks/zlib.py
@@ -5,7 +5,7 @@
import torch as ch
import zlib

-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack
from mimir.models import Model
from mimir.config import ExperimentConfig

16 changes: 11 additions & 5 deletions mimir/models.py
@@ -107,10 +107,15 @@ def get_probabilities(self,
target_ids = input_ids.clone()
target_ids[:, :-trg_len] = -100

-logits = self.model(input_ids, labels=target_ids).logits.cpu()
+logits = self.model(input_ids, labels=target_ids).logits
+if no_grads:
+    logits = logits.cpu()
shift_logits = logits[..., :-1, :].contiguous()
probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
-shift_labels = target_ids[..., 1:].cpu().contiguous()
+shift_labels = target_ids[..., 1:]
+if no_grads:
+    shift_labels = shift_labels.cpu()
+shift_labels = shift_labels.contiguous()
labels_processed = shift_labels[0]

del input_ids
@@ -125,9 +130,10 @@
# Should be equal to # of tokens - 1 to account for shift
assert len(all_prob) == labels.size(1) - 1

-if no_grads:
-    return all_prob
-return torch.tensor(all_prob)
+if not no_grads:
+    all_prob = torch.stack(all_prob)
+
+return all_prob

@torch.no_grad()
def get_ll(self,
29 changes: 14 additions & 15 deletions notebooks/new_mi_experiment.py
@@ -29,7 +29,7 @@
import mimir.plot_utils as plot_utils
from mimir.utils import fix_seed
from mimir.models import LanguageModel, ReferenceModel
-from mimir.attacks.blackbox_attacks import BlackBoxAttacks, Attack
+from mimir.attacks.all_attacks import AllAttacks, Attack
from mimir.attacks.neighborhood import T5Model, BertModel, NeighborhoodAttack
from mimir.attacks.utils import get_attacker

@@ -44,7 +44,7 @@ def get_attackers(
):
# Look at all attacks, and attacks that we have implemented
attacks = config.blackbox_attacks
-implemented_blackbox_attacks = [a.value for a in BlackBoxAttacks]
+implemented_blackbox_attacks = [a.value for a in AllAttacks]
# check for unimplemented attacks
runnable_attacks = []
for a in attacks:
@@ -57,16 +57,16 @@
# Initialize attackers
attackers = {}
for attack in attacks:
-if attack != BlackBoxAttacks.REFERENCE_BASED:
+if attack != AllAttacks.REFERENCE_BASED:
attackers[attack] = get_attacker(attack)(config, target_model)

# Initialize reference-based attackers if specified
if ref_models is not None:
for name, ref_model in ref_models.items():
-attacker = get_attacker(BlackBoxAttacks.REFERENCE_BASED)(
+attacker = get_attacker(AllAttacks.REFERENCE_BASED)(
config, target_model, ref_model
)
attackers[f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"] = attacker
attackers[f"{AllAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"] = attacker
return attackers


@@ -93,7 +93,7 @@ def get_mia_scores(

results = []
neighbors = None
-if BlackBoxAttacks.NEIGHBOR in attackers_dict.keys() and neigh_config.load_from_cache:
+if AllAttacks.NEIGHBOR in attackers_dict.keys() and neigh_config.load_from_cache:
neighbors = data[f"neighbors"]
print("Loaded neighbors from cache!")

@@ -136,16 +136,16 @@
detokenized_sample[i], tokens=substr, probs=s_tk_probs
)
)
-sample_information[BlackBoxAttacks.LOSS].append(loss)
+sample_information[AllAttacks.LOSS].append(loss)

# TODO: Shift functionality into each attack entirely, so that this is just a for loop
# For each attack
for attack, attacker in attackers_dict.items():
# LOSS already added above, Reference handled later
-if attack.startswith(BlackBoxAttacks.REFERENCE_BASED) or attack == BlackBoxAttacks.LOSS:
+if attack.startswith(AllAttacks.REFERENCE_BASED) or attack == AllAttacks.LOSS:
continue

-if attack != BlackBoxAttacks.NEIGHBOR:
+if attack != AllAttacks.NEIGHBOR:
score = attacker.attack(
substr,
probs=s_tk_probs,
@@ -191,7 +191,7 @@ def get_mia_scores(
# Perform reference-based attacks
if ref_models is not None:
for name, _ in ref_models.items():
ref_key = f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"
ref_key = f"{AllAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"
attacker = attackers_dict.get(ref_key, None)
if attacker is None:
continue
@@ -202,8 +202,7 @@
for i, s in enumerate(r["sample"]):
if config.pretokenized:
s = r["detokenized"][i]
-score = attacker.attack(s, probs=None,
-                        loss=r[BlackBoxAttacks.LOSS][i])
+score = attacker.attack(s, probs=None, loss=r[AllAttacks.LOSS][i])
ref_model_scores.append(score)
r[ref_key].extend(ref_model_scores)

@@ -275,7 +274,7 @@ def generate_data(
ref_models = None
if (
ref_config is not None
-and BlackBoxAttacks.REFERENCE_BASED in config.blackbox_attacks
+and AllAttacks.REFERENCE_BASED in config.blackbox_attacks
):
ref_models = {
model: ReferenceModel(config, model) for model in ref_config.models
@@ -289,9 +288,9 @@
if (
neigh_config
and (not neigh_config.load_from_cache)
-and (BlackBoxAttacks.NEIGHBOR in config.blackbox_attacks)
+and (AllAttacks.NEIGHBOR in config.blackbox_attacks)
):
-attacker_ne = attackers_dict[BlackBoxAttacks.NEIGHBOR]
+attacker_ne = attackers_dict[AllAttacks.NEIGHBOR]
mask_model = attacker_ne.get_mask_model()

print("MOVING BASE MODEL TO GPU...", end="", flush=True)