From 576339bb6f75ded0d1a69a99b904ab5682609ac9 Mon Sep 17 00:00:00 2001
From: Anshuman Suri
Date: Wed, 27 Mar 2024 12:49:16 -0400
Subject: [PATCH] Add gradnorm attack

---
 README.md                                   |  5 ++
 .../{blackbox_attacks.py => all_attacks.py} |  6 ++-
 mimir/attacks/gradnorm.py                   | 47 +++++++++++++++++++
 mimir/attacks/loss.py                       |  2 +-
 mimir/attacks/min_k.py                      |  2 +-
 mimir/attacks/neighborhood.py               |  2 +-
 mimir/attacks/quantile.py                   |  2 +-
 mimir/attacks/reference.py                  |  2 +-
 mimir/attacks/utils.py                      | 14 +++---
 mimir/attacks/zlib.py                       |  2 +-
 mimir/models.py                             | 16 +++++--
 notebooks/new_mi_experiment.py              | 29 ++++++------
 run.py                                      | 30 ++++++------
 tests/test_attacks.py                       |  4 +-
 14 files changed, 112 insertions(+), 51 deletions(-)
 rename mimir/attacks/{blackbox_attacks.py => all_attacks.py} (93%)
 create mode 100644 mimir/attacks/gradnorm.py

diff --git a/README.md b/README.md
index 963f0db..e0ba3b0 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,11 @@ We include and implement the following attacks, as described in our paper.
 - [Zlib Entropy](https://www.usenix.org/system/files/sec21-carlini-extracting.pdf) (`zlib`). Uses the zlib compression size of a sample to approximate local difficulty of sample.
 - [Min-k% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation.
 - [Neighborhood](https://aclanthology.org/2023.findings-acl.719/) (`ne`). Generates neighbors using auxiliary model and measures change in likelihood.
+- [Gradient Norm](https://arxiv.org/abs/2402.17012) (`gradnorm`). Uses the norm of the model's loss gradient on the target datapoint as its score.
+
+## Adding your own dataset
+
+To use your own dataset, either load your data directly inside `load_cached()` in `data_utils.py`, or add another if-else branch inside `load()` in `data_utils.py` if the data cannot easily be loaded from memory (or some other source). We plan to add a more general mechanism for this in the future.

 ## Adding your own attack

diff --git a/mimir/attacks/blackbox_attacks.py b/mimir/attacks/all_attacks.py
similarity index 93%
rename from mimir/attacks/blackbox_attacks.py
rename to mimir/attacks/all_attacks.py
index 72f9207..d864b6c 100644
--- a/mimir/attacks/blackbox_attacks.py
+++ b/mimir/attacks/all_attacks.py
@@ -7,22 +7,24 @@


 # Attack definitions
-class BlackBoxAttacks(str, Enum):
+class AllAttacks(str, Enum):
     LOSS = "loss" # Done
     REFERENCE_BASED = "ref" # Done
     ZLIB = "zlib" # Done
     MIN_K = "min_k" # Done
     NEIGHBOR = "ne" # Done
+    GRADNORM = "gradnorm" # Done
     # QUANTILE = "quantile" # Uncomment when tested implementation is available


 # Base attack class
 class Attack:
-    def __init__(self, config, target_model: Model, ref_model: Model = None):
+    def __init__(self, config, target_model: Model, ref_model: Model = None, is_blackbox: bool = True):
         self.config = config
         self.target_model = target_model
         self.ref_model = ref_model
         self.is_loaded = False
+        self.is_blackbox = is_blackbox

     def load(self):
         """
diff --git a/mimir/attacks/gradnorm.py b/mimir/attacks/gradnorm.py
new file mode 100644
index 0000000..e9e1c69
--- /dev/null
+++ b/mimir/attacks/gradnorm.py
@@ -0,0 +1,47 @@
+"""
+    Gradient-norm attack. Proposed for MIA in multiple settings, and particularly experimented for pre-training data and LLMs in https://arxiv.org/abs/2402.17012
+"""
+
+import torch as ch
+import numpy as np
+from mimir.attacks.all_attacks import Attack
+from mimir.models import Model
+from mimir.config import ExperimentConfig
+
+
+class GradNormAttack(Attack):
+    def __init__(self, config: ExperimentConfig, model: Model):
+        super().__init__(config, model, ref_model=None, is_blackbox=False)
+
+    def _attack(self, document, probs, tokens=None, **kwargs):
+        """
+        Gradient Norm Attack. Computes the p-norm of the loss gradient w.r.t. model parameters for the target document.
+        """
+        # We ignore probs here since they are computed in the general case without gradient-tracking (to save memory)
+
+        # Hyper-parameters for the gradient-norm attack
+        p: float = kwargs.get("p", np.inf)
+        if p not in [1, 2, np.inf]:
+            raise ValueError(f"Invalid p-norm value: {p}.")
+
+        # Make sure model params require gradients
+        # for name, param in self.target_model.model.named_parameters():
+        #     param.requires_grad = True
+
+        # Get gradients for model parameters
+        self.target_model.model.zero_grad()
+        all_prob = self.target_model.get_probabilities(document, tokens=tokens, no_grads=False)
+        loss = - ch.mean(all_prob)
+        loss.backward()
+
+        # Compute p-norm of gradients (for all model params where grad exists)
+        grad_norms = []
+        for param in self.target_model.model.parameters():
+            if param.grad is not None:
+                grad_norms.append(param.grad.detach().norm(p))
+        grad_norm = ch.stack(grad_norms).mean()
+
+        # Zero out gradients again
+        self.target_model.model.zero_grad()
+
+        return -grad_norm.cpu().numpy()
diff --git a/mimir/attacks/loss.py b/mimir/attacks/loss.py
index c1bda89..15452b8 100644
--- a/mimir/attacks/loss.py
+++ b/mimir/attacks/loss.py
@@ -2,7 +2,7 @@
     Straight-forward LOSS attack, as described in https://ieeexplore.ieee.org/abstract/document/8429311
 """
 import torch as ch
-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack
 from mimir.models import Model
 from mimir.config import ExperimentConfig

diff --git a/mimir/attacks/min_k.py b/mimir/attacks/min_k.py
index 908e20c..2fa1c70 100644
--- a/mimir/attacks/min_k.py
+++ b/mimir/attacks/min_k.py
@@ -3,7 +3,7 @@
 """
 import torch as ch
 import numpy as np
-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack
 from mimir.models import Model
 from mimir.config import ExperimentConfig

diff --git a/mimir/attacks/neighborhood.py b/mimir/attacks/neighborhood.py
index d7b9804..aed049b 100644
--- a/mimir/attacks/neighborhood.py
+++ b/mimir/attacks/neighborhood.py
@@ -14,7 +14,7 @@
 from mimir.config import ExperimentConfig
 from mimir.attacks.attack_utils import count_masks, apply_extracted_fills
 from mimir.models import Model, ReferenceModel
-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack


 class NeighborhoodAttack(Attack):
diff --git a/mimir/attacks/quantile.py b/mimir/attacks/quantile.py
index 5a1e67e..4e748dd 100644
--- a/mimir/attacks/quantile.py
+++ b/mimir/attacks/quantile.py
@@ -9,7 +9,7 @@
 from transformers import TrainingArguments, Trainer
 from datasets import Dataset

-from mimir.attacks.blackbox_attacks import Attack
+from mimir.attacks.all_attacks import Attack


 class CustomTrainer(Trainer):
diff --git a/mimir/attacks/reference.py b/mimir/attacks/reference.py
index a568205..bd85888 100644
--- a/mimir/attacks/reference.py
+++ b/mimir/attacks/reference.py
@@ -1,7 +1,7 @@
 """
     Reference-based attacks.
""" -from mimir.attacks.blackbox_attacks import Attack +from mimir.attacks.all_attacks import Attack from mimir.models import Model, ReferenceModel from mimir.config import ExperimentConfig diff --git a/mimir/attacks/utils.py b/mimir/attacks/utils.py index 1339e55..ac81dc6 100644 --- a/mimir/attacks/utils.py +++ b/mimir/attacks/utils.py @@ -1,20 +1,22 @@ -from mimir.attacks.blackbox_attacks import BlackBoxAttacks +from mimir.attacks.all_attacks import AllAttacks from mimir.attacks.loss import LOSSAttack from mimir.attacks.reference import ReferenceAttack from mimir.attacks.zlib import ZLIBAttack from mimir.attacks.min_k import MinKProbAttack from mimir.attacks.neighborhood import NeighborhoodAttack +from mimir.attacks.gradnorm import GradNormAttack # TODO Use decorators to link attack implementations with enum above def get_attacker(attack: str): mapping = { - BlackBoxAttacks.LOSS: LOSSAttack, - BlackBoxAttacks.REFERENCE_BASED: ReferenceAttack, - BlackBoxAttacks.ZLIB: ZLIBAttack, - BlackBoxAttacks.MIN_K: MinKProbAttack, - BlackBoxAttacks.NEIGHBOR: NeighborhoodAttack, + AllAttacks.LOSS: LOSSAttack, + AllAttacks.REFERENCE_BASED: ReferenceAttack, + AllAttacks.ZLIB: ZLIBAttack, + AllAttacks.MIN_K: MinKProbAttack, + AllAttacks.NEIGHBOR: NeighborhoodAttack, + AllAttacks.GRADNORM: GradNormAttack, } attack_cls = mapping.get(attack, None) if attack_cls is None: diff --git a/mimir/attacks/zlib.py b/mimir/attacks/zlib.py index 1624bef..ff85914 100644 --- a/mimir/attacks/zlib.py +++ b/mimir/attacks/zlib.py @@ -5,7 +5,7 @@ import torch as ch import zlib -from mimir.attacks.blackbox_attacks import Attack +from mimir.attacks.all_attacks import Attack from mimir.models import Model from mimir.config import ExperimentConfig diff --git a/mimir/models.py b/mimir/models.py index f43f848..4a8897e 100644 --- a/mimir/models.py +++ b/mimir/models.py @@ -107,10 +107,15 @@ def get_probabilities(self, target_ids = input_ids.clone() target_ids[:, :-trg_len] = -100 - logits = self.model(input_ids, labels=target_ids).logits.cpu() + logits = self.model(input_ids, labels=target_ids).logits + if no_grads: + logits = logits.cpu() shift_logits = logits[..., :-1, :].contiguous() probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1) - shift_labels = target_ids[..., 1:].cpu().contiguous() + shift_labels = target_ids[..., 1:] + if no_grads: + shift_labels = shift_labels.cpu() + shift_labels = shift_labels.contiguous() labels_processed = shift_labels[0] del input_ids @@ -125,9 +130,10 @@ def get_probabilities(self, # Should be equal to # of tokens - 1 to account for shift assert len(all_prob) == labels.size(1) - 1 - if no_grads: - return all_prob - return torch.tensor(all_prob) + if not no_grads: + all_prob = torch.stack(all_prob) + + return all_prob @torch.no_grad() def get_ll(self, diff --git a/notebooks/new_mi_experiment.py b/notebooks/new_mi_experiment.py index 5bb15d2..97a9727 100644 --- a/notebooks/new_mi_experiment.py +++ b/notebooks/new_mi_experiment.py @@ -29,7 +29,7 @@ import mimir.plot_utils as plot_utils from mimir.utils import fix_seed from mimir.models import LanguageModel, ReferenceModel -from mimir.attacks.blackbox_attacks import BlackBoxAttacks, Attack +from mimir.attacks.all_attacks import AllAttacks, Attack from mimir.attacks.neighborhood import T5Model, BertModel, NeighborhoodAttack from mimir.attacks.utils import get_attacker @@ -44,7 +44,7 @@ def get_attackers( ): # Look at all attacks, and attacks that we have implemented attacks = config.blackbox_attacks - implemented_blackbox_attacks = 
     # check for unimplemented attacks
     runnable_attacks = []
     for a in attacks:
@@ -57,16 +57,16 @@ def get_attackers(
     # Initialize attackers
     attackers = {}
     for attack in attacks:
-        if attack != BlackBoxAttacks.REFERENCE_BASED:
+        if attack != AllAttacks.REFERENCE_BASED:
             attackers[attack] = get_attacker(attack)(config, target_model)

     # Initialize reference-based attackers if specified
     if ref_models is not None:
         for name, ref_model in ref_models.items():
-            attacker = get_attacker(BlackBoxAttacks.REFERENCE_BASED)(
+            attacker = get_attacker(AllAttacks.REFERENCE_BASED)(
                 config, target_model, ref_model
             )
-            attackers[f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"] = attacker
+            attackers[f"{AllAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"] = attacker
     return attackers


@@ -93,7 +93,7 @@ def get_mia_scores(
     results = []

     neighbors = None
-    if BlackBoxAttacks.NEIGHBOR in attackers_dict.keys() and neigh_config.load_from_cache:
+    if AllAttacks.NEIGHBOR in attackers_dict.keys() and neigh_config.load_from_cache:
         neighbors = data[f"neighbors"]
         print("Loaded neighbors from cache!")

@@ -136,16 +136,16 @@ def get_mia_scores(
                     detokenized_sample[i], tokens=substr, probs=s_tk_probs
                 )
             )
-            sample_information[BlackBoxAttacks.LOSS].append(loss)
+            sample_information[AllAttacks.LOSS].append(loss)

             # TODO: Shift functionality into each attack entirely, so that this is just a for loop
             # For each attack
             for attack, attacker in attackers_dict.items():
                 # LOSS already added above, Reference handled later
-                if attack.startswith(BlackBoxAttacks.REFERENCE_BASED) or attack == BlackBoxAttacks.LOSS:
+                if attack.startswith(AllAttacks.REFERENCE_BASED) or attack == AllAttacks.LOSS:
                     continue
-                if attack != BlackBoxAttacks.NEIGHBOR:
+                if attack != AllAttacks.NEIGHBOR:
                     score = attacker.attack(
                         substr,
                         probs=s_tk_probs,
@@ -191,7 +191,7 @@ def get_mia_scores(
     # Perform reference-based attacks
     if ref_models is not None:
         for name, _ in ref_models.items():
-            ref_key = f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"
+            ref_key = f"{AllAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"
             attacker = attackers_dict.get(ref_key, None)
             if attacker is None:
                 continue
@@ -202,8 +202,7 @@ def get_mia_scores(
             for i, s in enumerate(r["sample"]):
                 if config.pretokenized:
                     s = r["detokenized"][i]
-                score = attacker.attack(s, probs=None,
-                                        loss=r[BlackBoxAttacks.LOSS][i])
+                score = attacker.attack(s, probs=None, loss=r[AllAttacks.LOSS][i])
                 ref_model_scores.append(score)
             r[ref_key].extend(ref_model_scores)

@@ -275,7 +274,7 @@ def generate_data(
 ref_models = None
 if (
     ref_config is not None
-    and BlackBoxAttacks.REFERENCE_BASED in config.blackbox_attacks
+    and AllAttacks.REFERENCE_BASED in config.blackbox_attacks
 ):
     ref_models = {
         model: ReferenceModel(config, model) for model in ref_config.models
@@ -289,9 +288,9 @@
 if (
     neigh_config
     and (not neigh_config.load_from_cache)
-    and (BlackBoxAttacks.NEIGHBOR in config.blackbox_attacks)
+    and (AllAttacks.NEIGHBOR in config.blackbox_attacks)
 ):
-    attacker_ne = attackers_dict[BlackBoxAttacks.NEIGHBOR]
+    attacker_ne = attackers_dict[AllAttacks.NEIGHBOR]
     mask_model = attacker_ne.get_mask_model()

     print("MOVING BASE MODEL TO GPU...", end="", flush=True)
diff --git a/run.py b/run.py
index ec9f388..b10fb1a 100644
--- a/run.py
+++ b/run.py
@@ -25,7 +25,7 @@
 import mimir.plot_utils as plot_utils
 from mimir.utils import fix_seed
 from mimir.models import LanguageModel, ReferenceModel, OpenAI_APIModel
-from mimir.attacks.blackbox_attacks import BlackBoxAttacks, Attack
+from mimir.attacks.all_attacks import AllAttacks, Attack
 from mimir.attacks.utils import get_attacker
 from mimir.attacks.attack_utils import (
     get_roc_metrics,
@@ -41,7 +41,7 @@ def get_attackers(
 ):
     # Look at all attacks, and attacks that we have implemented
     attacks = config.blackbox_attacks
-    implemented_blackbox_attacks = [a.value for a in BlackBoxAttacks]
+    implemented_blackbox_attacks = [a.value for a in AllAttacks]
     # check for unimplemented attacks
     runnable_attacks = []
     for a in attacks:
@@ -54,16 +54,16 @@ def get_attackers(
     # Initialize attackers
     attackers = {}
     for attack in attacks:
-        if attack != BlackBoxAttacks.REFERENCE_BASED:
+        if attack != AllAttacks.REFERENCE_BASED:
             attackers[attack] = get_attacker(attack)(config, target_model)

     # Initialize reference-based attackers if specified
     if ref_models is not None:
         for name, ref_model in ref_models.items():
-            attacker = get_attacker(BlackBoxAttacks.REFERENCE_BASED)(
+            attacker = get_attacker(AllAttacks.REFERENCE_BASED)(
                 config, target_model, ref_model
             )
-            attackers[f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"] = attacker
+            attackers[f"{AllAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"] = attacker
     return attackers


@@ -92,7 +92,7 @@ def get_mia_scores(
     results = []

     neighbors = None
-    if BlackBoxAttacks.NEIGHBOR in attackers_dict.keys() and neigh_config.load_from_cache:
+    if AllAttacks.NEIGHBOR in attackers_dict.keys() and neigh_config.load_from_cache:
         neighbors = data[f"neighbors"]
         print("Loaded neighbors from cache!")

@@ -140,16 +140,16 @@ def get_mia_scores(
                     detokenized_sample[i], tokens=substr, probs=s_tk_probs
                 )
             )
-            sample_information[BlackBoxAttacks.LOSS].append(loss)
+            sample_information[AllAttacks.LOSS].append(loss)

             # TODO: Shift functionality into each attack entirely, so that this is just a for loop
             # For each attack
             for attack, attacker in attackers_dict.items():
                 # LOSS already added above, Reference handled later
-                if attack.startswith(BlackBoxAttacks.REFERENCE_BASED) or attack == BlackBoxAttacks.LOSS:
+                if attack.startswith(AllAttacks.REFERENCE_BASED) or attack == AllAttacks.LOSS:
                     continue
-                if attack != BlackBoxAttacks.NEIGHBOR:
+                if attack != AllAttacks.NEIGHBOR:
                     score = attacker.attack(
                         substr,
                         probs=s_tk_probs,
@@ -229,7 +229,7 @@ def get_mia_scores(
     # Perform reference-based attacks
     if ref_models is not None:
         for name, ref_model in ref_models.items():
-            ref_key = f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"
+            ref_key = f"{AllAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"
             attacker = attackers_dict.get(ref_key, None)
             if attacker is None:
                 continue
@@ -241,7 +241,7 @@ def get_mia_scores(
                 if config.pretokenized:
                     s = r["detokenized"][i]
                 score = attacker.attack(s, probs=None,
-                                        loss=r[BlackBoxAttacks.LOSS][i])
+                                        loss=r[AllAttacks.LOSS][i])
                 ref_model_scores.append(score)
             r[ref_key].extend(ref_model_scores)

@@ -477,7 +477,7 @@ def main(config: ExperimentConfig):
     ref_models = None
     if (
         ref_config is not None
-        and BlackBoxAttacks.REFERENCE_BASED in config.blackbox_attacks
+        and AllAttacks.REFERENCE_BASED in config.blackbox_attacks
     ):
         ref_models = {
             model: ReferenceModel(config, model) for model in ref_config.models
@@ -491,9 +491,9 @@ def main(config: ExperimentConfig):
     if (
         neigh_config
         and (not neigh_config.load_from_cache)
-        and (BlackBoxAttacks.NEIGHBOR in config.blackbox_attacks)
+        and (AllAttacks.NEIGHBOR in config.blackbox_attacks)
     ):
-        attacker_ne = attackers_dict[BlackBoxAttacks.NEIGHBOR]
+        attacker_ne = attackers_dict[AllAttacks.NEIGHBOR]
         mask_model = attacker_ne.get_mask_model()

         print("MOVING BASE MODEL TO GPU...", end="", flush=True)
@@ -549,7 +549,7 @@ def main(config: ExperimentConfig):
     # If neighborhood attack is used, see if we have cache available (and load from it, if we do)
     neighbors_nonmember, neighbors_member = None, None
     if (
-        BlackBoxAttacks.NEIGHBOR in config.blackbox_attacks
+        AllAttacks.NEIGHBOR in config.blackbox_attacks
         and neigh_config.load_from_cache
     ):
         neighbors_nonmember, neighbors_member = {}, {}
diff --git a/tests/test_attacks.py b/tests/test_attacks.py
index 949e6e9..3d6be5b 100644
--- a/tests/test_attacks.py
+++ b/tests/test_attacks.py
@@ -6,7 +6,7 @@
 import numpy as np
 import torch.nn as nn

-from mimir.attacks.blackbox_attacks import BlackBoxAttacks
+from mimir.attacks.all_attacks import AllAttacks
 from mimir.attacks.utils import get_attacker


@@ -16,7 +16,7 @@ def test_attacks_exist(self):
         Check if all known attacks can be loaded.
         """
         # Enumerate all "available" attacks and make sure they are available
-        for attack in BlackBoxAttacks:
+        for attack in AllAttacks:
             attacker = get_attacker(attack)
             assert attacker is not None, f"Attack {attack} not found"
             # TODO: Use a 'testing' config and model to check if the attack can be loaded
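
Note (not part of the patch): a minimal usage sketch of the new gradient-norm attack through the package API. It assumes `config` is an ExperimentConfig and `target_model` is a mimir.models model wrapper already constructed as in run.py, and that the base Attack.attack() wrapper forwards extra keyword arguments (such as p) to _attack(). The helper name gradnorm_score, the document argument, and p=2 are illustrative only, not part of the codebase.

    from mimir.attacks.all_attacks import AllAttacks
    from mimir.attacks.utils import get_attacker

    def gradnorm_score(config, target_model, document: str) -> float:
        # get_attacker maps the enum value to the GradNormAttack class added in this patch
        attacker = get_attacker(AllAttacks.GRADNORM)(config, target_model)
        # probs is ignored by this attack; p selects the norm (1, 2, or np.inf, the default).
        # The returned score is the negated mean p-norm of the per-parameter loss gradients.
        return float(attacker.attack(document, probs=None, p=2))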