diff --git a/README.md b/README.md
index e0ba3b0..a684518 100644
--- a/README.md
+++ b/README.md
@@ -51,8 +51,9 @@ We include and implement the following attacks, as described in our paper.
 - [Likelihood](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8429311) (`loss`). Works by simply using the likelihood of the target datapoint as score.
 - [Reference-based](https://arxiv.org/abs/2004.15011) (`ref`). Normalizes likelihood score with score obtained from a reference model.
 - [Zlib Entropy](https://www.usenix.org/system/files/sec21-carlini-extracting.pdf) (`zlib`). Uses the zlib compression size of a sample to approximate local difficulty of sample.
-- [Min-k% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation.
 - [Neighborhood](https://aclanthology.org/2023.findings-acl.719/) (`ne`). Generates neighbors using auxiliary model and measures change in likelihood.
+- [Min-K% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation.
+- [Min-K%++](https://zjysteven.github.io/mink-plus-plus/) (`min_k++`). Uses k% of tokens with minimum *normalized* likelihood for score computation.
 - [Gradient Norm](https://arxiv.org/abs/2402.17012) (`gradnorm`). Uses gradient norm of the target datapoint as score.
 
 ## Adding your own dataset
diff --git a/mimir/attacks/all_attacks.py b/mimir/attacks/all_attacks.py
index d864b6c..21fc613 100644
--- a/mimir/attacks/all_attacks.py
+++ b/mimir/attacks/all_attacks.py
@@ -12,6 +12,7 @@ class AllAttacks(str, Enum):
     REFERENCE_BASED = "ref" # Done
     ZLIB = "zlib" # Done
     MIN_K = "min_k" # Done
+    MIN_K_PLUS_PLUS = "min_k++" # Done
     NEIGHBOR = "ne" # Done
     GRADNORM = "gradnorm" # Done
     # QUANTILE = "quantile" # Uncomment when tested implementation is available
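For context on the new `min_k++` entry: Min-K%++ standardizes each observed token's log-likelihood by the mean and standard deviation of the model's own next-token log-probability distribution at that position, then averages the lowest k% of the standardized values. A minimal, self-contained sketch of that computation (illustrative only, not part of the patch; the tensor names here are made up):

```python
import torch

def mink_plus_plus_score(target_log_probs: torch.Tensor,  # (T,) log p(x_t | x_<t) of observed tokens
                         full_log_probs: torch.Tensor,    # (T, V) log-probs over the whole vocabulary
                         k: float = 0.2) -> float:
    """Illustrative Min-K%++ score for a single sequence."""
    probs = full_log_probs.exp()
    # Per-position mean and variance of the log-probability under the model's
    # own next-token distribution.
    mu = (probs * full_log_probs).sum(dim=-1)                      # (T,)
    var = (probs * full_log_probs.square()).sum(dim=-1) - mu ** 2  # (T,)
    z = (target_log_probs - mu) / var.clamp_min(1e-12).sqrt()
    # Average the lowest k% of standardized scores; the sign is flipped so the
    # score follows the same "higher means more surprising" convention as the
    # repo's other loss-style attacks.
    num = max(1, int(len(z) * k))
    return -torch.sort(z).values[:num].mean().item()
```

The implementation added in `mimir/attacks/min_k_plus_plus.py` below follows the same formula, computing `mu` and the variance from the `prob`/`log_prob` tensors returned by `Model.get_probabilities`.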
diff --git a/mimir/attacks/min_k_plus_plus.py b/mimir/attacks/min_k_plus_plus.py
new file mode 100644
index 0000000..7817edc
--- /dev/null
+++ b/mimir/attacks/min_k_plus_plus.py
@@ -0,0 +1,37 @@
+"""
+    Min-K%++ Attack: https://github.com/zjysteven/mink-plus-plus
+"""
+import torch as ch
+import numpy as np
+from mimir.attacks.all_attacks import Attack
+from mimir.models import Model
+from mimir.config import ExperimentConfig
+
+
+class MinKPlusPlusAttack(Attack):
+
+    def __init__(self, config: ExperimentConfig, model: Model):
+        super().__init__(config, model, ref_model=None)
+
+    @ch.no_grad()
+    def _attack(self, document, probs, tokens=None, **kwargs):
+        """
+        Min-K%++ Attack.
+        Gets token log-probabilities, normalizes each with the mean and std of the model's full next-token distribution,
+        and returns the negated mean over the k% of tokens with the lowest normalized likelihood.
+        """
+        # Hyper-params specific to the min-k++ attack
+        k: float = kwargs.get("k", 0.2)
+        all_probs = kwargs.get("all_probs", None)
+
+        target_prob, all_prob = (
+            (probs, all_probs)
+            if (probs is not None and all_probs is not None)
+            else self.model.get_probabilities(document, tokens=tokens, return_all_probs=True)
+        )
+
+        mu = (all_prob['prob'] * all_prob['log_prob']).sum(-1)
+        sigma = (all_prob['prob'] * ch.square(all_prob['log_prob'])).sum(-1) - ch.square(mu)
+        scores = (np.array(target_prob) - mu.numpy()) / sigma.sqrt().numpy()
+
+        return -np.mean(sorted(scores)[:int(len(scores) * k)])
\ No newline at end of file
diff --git a/mimir/attacks/utils.py b/mimir/attacks/utils.py
index ac81dc6..05b66cc 100644
--- a/mimir/attacks/utils.py
+++ b/mimir/attacks/utils.py
@@ -4,6 +4,7 @@
 from mimir.attacks.reference import ReferenceAttack
 from mimir.attacks.zlib import ZLIBAttack
 from mimir.attacks.min_k import MinKProbAttack
+from mimir.attacks.min_k_plus_plus import MinKPlusPlusAttack
 from mimir.attacks.neighborhood import NeighborhoodAttack
 from mimir.attacks.gradnorm import GradNormAttack
 
@@ -15,6 +16,7 @@ def get_attacker(attack: str):
         AllAttacks.REFERENCE_BASED: ReferenceAttack,
         AllAttacks.ZLIB: ZLIBAttack,
         AllAttacks.MIN_K: MinKProbAttack,
+        AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
         AllAttacks.NEIGHBOR: NeighborhoodAttack,
         AllAttacks.GRADNORM: GradNormAttack,
     }
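With the registry entry above in place, the new attack is resolved like the existing ones. A short usage sketch, assuming `get_attacker` returns the mapped class (as its name and the registry suggest) and that `config` is an `ExperimentConfig` and `target_model` a loaded mimir `Model` (both assumed here):

```python
from mimir.attacks.all_attacks import AllAttacks
from mimir.attacks.utils import get_attacker

# Look up the attack class from the registry ("min_k++" -> MinKPlusPlusAttack)
attacker_cls = get_attacker(AllAttacks.MIN_K_PLUS_PLUS)
# Instantiate it with the signature shown in the new file's __init__
attacker = attacker_cls(config, target_model)
```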
diff --git a/mimir/models.py b/mimir/models.py
index 4a8897e..3e25bb5 100644
--- a/mimir/models.py
+++ b/mimir/models.py
@@ -70,7 +70,8 @@ def unload(self):
     def get_probabilities(self,
                           text: str,
                           tokens: np.ndarray = None,
-                          no_grads: bool = True):
+                          no_grads: bool = True,
+                          return_all_probs: bool = False):
         """
             Get the probabilities or log-softmaxed logits for a text under the current model.
             Args:
@@ -98,7 +99,11 @@
             text, return_tensors="pt")
         labels = tokenized.input_ids
 
-        all_prob = []
+        target_token_log_prob = []
+        all_token_prob = {
+            'log_prob': [],
+            'prob': []
+        }
         for i in range(0, labels.size(1), self.stride):
             begin_loc = max(i + self.stride - self.max_length, 0)
             end_loc = min(i + self.stride, labels.size(1))
@@ -111,7 +116,8 @@
             if no_grads:
                 logits = logits.cpu()
             shift_logits = logits[..., :-1, :].contiguous()
-            probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
+            probabilities = torch.nn.functional.softmax(shift_logits, dim=-1)
+            log_probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
             shift_labels = target_ids[..., 1:]
             if no_grads:
                 shift_labels = shift_labels.cpu()
@@ -123,17 +129,26 @@
 
             for i, token_id in enumerate(labels_processed):
                 if token_id != -100:
-                    probability = probabilities[0, i, token_id]
+                    log_probability = log_probabilities[0, i, token_id]
                     if no_grads:
-                        probability = probability.item()
-                    all_prob.append(probability)
+                        log_probability = log_probability.item()
+                    target_token_log_prob.append(log_probability)
+                    all_token_prob['log_prob'].append(log_probabilities[0, i])
+                    all_token_prob["prob"].append(probabilities[0, i])
+
         # Should be equal to # of tokens - 1 to account for shift
-        assert len(all_prob) == labels.size(1) - 1
+        assert len(target_token_log_prob) == labels.size(1) - 1
+        all_token_prob['log_prob'] = torch.stack(all_token_prob['log_prob'], dim=0)
+        all_token_prob['prob'] = torch.stack(all_token_prob['prob'], dim=0)
+        assert len(target_token_log_prob) == len(all_token_prob['log_prob'])
+        assert len(target_token_log_prob) == len(all_token_prob['prob'])
 
         if not no_grads:
-            all_prob = torch.stack(all_prob)
+            target_token_log_prob = torch.stack(target_token_log_prob)
 
-        return all_prob
+        if not return_all_probs:
+            return target_token_log_prob
+        return target_token_log_prob, all_token_prob
 
     @torch.no_grad()
     def get_ll(self,
diff --git a/run.py b/run.py
index b10fb1a..4dddfa8 100644
--- a/run.py
+++ b/run.py
@@ -124,11 +124,11 @@ def get_mia_scores(
             neighbors_within = {n_perturbation: [] for n_perturbation in n_perturbation_list}
             for i, substr in enumerate(sample):
                 # compute token probabilities for sample
-                s_tk_probs = (
-                    target_model.get_probabilities(substr)
+                s_tk_probs, s_all_probs = (
+                    target_model.get_probabilities(substr, return_all_probs=True)
                     if not config.pretokenized
                     else target_model.get_probabilities(
-                        detokenized_sample[i], tokens=substr
+                        detokenized_sample[i], tokens=substr, return_all_probs=True
                     )
                 )
 
@@ -150,17 +150,38 @@
                         continue
 
                     if attack != AllAttacks.NEIGHBOR:
-                        score = attacker.attack(
-                            substr,
-                            probs=s_tk_probs,
-                            detokenized_sample=(
-                                detokenized_sample[i]
-                                if config.pretokenized
-                                else None
-                            ),
-                            loss=loss,
-                        )
-                        sample_information[attack].append(score)
+                        if attack in [AllAttacks.MIN_K, AllAttacks.MIN_K_PLUS_PLUS]:
+                            # For Min-K and Min-K++
+                            # iterate over k
+                            for k in [
+                                0.1, 0.2, 0.3, 0.4, 0.5,
+                                0.6, 0.7, 0.8, 0.9, 1.0
+                            ]:
+                                score = attacker.attack(
+                                    substr,
+                                    probs=s_tk_probs,
+                                    detokenized_sample=(
+                                        detokenized_sample[i]
+                                        if config.pretokenized
+                                        else None
+                                    ),
+                                    loss=loss,
+                                    all_probs=s_all_probs,  # for Min-K%++
+                                    k=k
+                                )
+                                sample_information[f"{attack}_{k}"].append(score)
+                        else:
+                            score = attacker.attack(
+                                substr,
+                                probs=s_tk_probs,
+                                detokenized_sample=(
+                                    detokenized_sample[i]
+                                    if config.pretokenized
+                                    else None
+                                ),
+                                loss=loss,
+                            )
+                            sample_information[attack].append(score)
                     else:
                         # For each 'number of neighbors'
                         for n_perturbation in n_perturbation_list:
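To make the updated `get_probabilities` contract concrete, here is a small consumption sketch (not part of the patch; `target_model` is assumed to be a loaded mimir `Model` and `attacker` a `MinKPlusPlusAttack` instance):

```python
text = "some candidate passage"  # arbitrary example input

# Default path is unchanged: a list of per-token log-probabilities of the
# observed tokens (length = number of tokens - 1, due to the shift).
target_log_probs = target_model.get_probabilities(text)

# New path: additionally returns a dict with 'log_prob' and 'prob' tensors
# of shape (num_tokens - 1, vocab_size) over the full vocabulary.
target_log_probs, all_token_prob = target_model.get_probabilities(
    text, return_all_probs=True
)

# These feed directly into the Min-K%++ scoring routine added above; the
# private _attack is called here for illustration, whereas run.py goes
# through attacker.attack(...) and also passes loss and detokenized_sample.
score = attacker._attack(
    text,
    probs=target_log_probs,
    all_probs=all_token_prob,
    k=0.2,
)
```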