Commit 206c010
zjysteven committed Apr 4, 2024 (1 parent: 79760db)
Showing 6 changed files with 101 additions and 24 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -51,8 +51,9 @@ We include and implement the following attacks, as described in our paper.
 - [Likelihood](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8429311) (`loss`). Works by simply using the likelihood of the target datapoint as score.
 - [Reference-based](https://arxiv.org/abs/2004.15011) (`ref`). Normalizes likelihood score with score obtained from a reference model.
 - [Zlib Entropy](https://www.usenix.org/system/files/sec21-carlini-extracting.pdf) (`zlib`). Uses the zlib compression size of a sample to approximate local difficulty of sample.
-- [Min-k% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation.
 - [Neighborhood](https://aclanthology.org/2023.findings-acl.719/) (`ne`). Generates neighbors using auxiliary model and measures change in likelihood.
+- [Min-K% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation.
+- [Min-K%++](https://zjysteven.github.io/mink-plus-plus/) (`min_k++`). Uses k% of tokens with minimum *normalized* likelihood for score computation.
 - [Gradient Norm](https://arxiv.org/abs/2402.17012) (`gradnorm`). Uses gradient norm of the target datapoint as score.
 
 ## Adding your own dataset
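For intuition, here is a minimal standalone sketch of the Min-K%++ score from the new README entry, using synthetic tensors (the names, shapes, and seed are illustrative, not part of the MIMIR API): each observed token's log-likelihood is z-normalized by the mean and standard deviation of log-probabilities under the model's full next-token distribution at that position, and the k% lowest normalized scores are averaged.

```python
import torch

torch.manual_seed(0)
seq_len, vocab_size = 8, 50

# Synthetic per-position next-token distributions (stand-ins for model output)
log_p = torch.log_softmax(torch.randn(seq_len, vocab_size), dim=-1)
p = log_p.exp()
targets = torch.randint(vocab_size, (seq_len,))        # observed token ids
target_log_p = log_p[torch.arange(seq_len), targets]   # log p(x_t | x_<t)

mu = (p * log_p).sum(-1)                               # E_z[log p(z)] per position
var = (p * log_p.square()).sum(-1) - mu.square()       # Var_z[log p(z)]
scores = (target_log_p - mu) / var.sqrt()              # normalized log-likelihood

k = 0.2  # keep the 20% of tokens with the lowest normalized scores
lowest = scores.sort().values[: int(seq_len * k)]
min_k_pp = -lowest.mean()  # negated, matching mimir/attacks/min_k_plus_plus.py below
print(float(min_k_pp))
```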
mimir/attacks/all_attacks.py (1 change: 1 addition & 0 deletions)
@@ -12,6 +12,7 @@ class AllAttacks(str, Enum):
     REFERENCE_BASED = "ref"  # Done
     ZLIB = "zlib"  # Done
     MIN_K = "min_k"  # Done
+    MIN_K_PLUS_PLUS = "min_k++"  # Done
     NEIGHBOR = "ne"  # Done
     GRADNORM = "gradnorm"  # Done
     # QUANTILE = "quantile" # Uncomment when tested implementation is available
mimir/attacks/min_k_plus_plus.py (37 changes: 37 additions & 0 deletions)
@@ -0,0 +1,37 @@
+"""
+Min-K%++ Attack: https://github.com/zjysteven/mink-plus-plus
+"""
+import torch as ch
+import numpy as np
+from mimir.attacks.all_attacks import Attack
+from mimir.models import Model
+from mimir.config import ExperimentConfig
+
+
+class MinKPlusPlusAttack(Attack):
+
+    def __init__(self, config: ExperimentConfig, model: Model):
+        super().__init__(config, model, ref_model=None)
+
+    @ch.no_grad()
+    def _attack(self, document, probs, tokens=None, **kwargs):
+        """
+        Min-K%++ Attack.
+        Gets token log-probabilities, normalizes each with the mean and std of the model's
+        full next-token distribution, and averages over the k% of tokens with the lowest normalized values.
+        """
+        # Hyper-params specific to the Min-K%++ attack
+        k: float = kwargs.get("k", 0.2)
+        all_probs = kwargs.get("all_probs", None)
+
+        target_prob, all_prob = (
+            (probs, all_probs)
+            if (probs is not None and all_probs is not None)
+            else self.model.get_probabilities(document, tokens=tokens, return_all_probs=True)
+        )
+
+        mu = (all_prob['prob'] * all_prob['log_prob']).sum(-1)
+        sigma = (all_prob['prob'] * ch.square(all_prob['log_prob'])).sum(-1) - ch.square(mu)
+        scores = (np.array(target_prob) - mu.numpy()) / sigma.sqrt().numpy()
+
+        return -np.mean(sorted(scores)[:int(len(scores) * k)])
mimir/attacks/utils.py (2 changes: 2 additions & 0 deletions)
@@ -4,6 +4,7 @@
 from mimir.attacks.reference import ReferenceAttack
 from mimir.attacks.zlib import ZLIBAttack
 from mimir.attacks.min_k import MinKProbAttack
+from mimir.attacks.min_k_plus_plus import MinKPlusPlusAttack
 from mimir.attacks.neighborhood import NeighborhoodAttack
 from mimir.attacks.gradnorm import GradNormAttack
 
@@ -15,6 +16,7 @@ def get_attacker(attack: str):
         AllAttacks.REFERENCE_BASED: ReferenceAttack,
         AllAttacks.ZLIB: ZLIBAttack,
         AllAttacks.MIN_K: MinKProbAttack,
+        AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
         AllAttacks.NEIGHBOR: NeighborhoodAttack,
         AllAttacks.GRADNORM: GradNormAttack,
     }
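With the registry updated, Min-K%++ is reachable through the same factory as every other attack. A hedged end-to-end sketch (the wrapper function is hypothetical; `config` and `target_model` stand for whatever `ExperimentConfig` and `Model` instances the experiment already builds, while `get_attacker`, the enum member, and the `attack` keyword arguments come from this commit):

```python
from mimir.attacks.all_attacks import AllAttacks
from mimir.attacks.utils import get_attacker


def min_k_plus_plus_score(config, target_model, document: str, k: float = 0.2):
    """Hypothetical convenience wrapper around the new attack."""
    attacker = get_attacker(AllAttacks.MIN_K_PLUS_PLUS)(config, target_model)
    # One forward pass yields token log-probs plus full distributions (see models.py below)
    probs, all_probs = target_model.get_probabilities(document, return_all_probs=True)
    return attacker.attack(document, probs=probs, all_probs=all_probs, k=k)
```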
mimir/models.py (33 changes: 24 additions & 9 deletions)
@@ -70,7 +70,8 @@ def unload(self):
     def get_probabilities(self,
                           text: str,
                           tokens: np.ndarray = None,
-                          no_grads: bool = True):
+                          no_grads: bool = True,
+                          return_all_probs: bool = False):
         """
         Get the probabilities or log-softmaxed logits for a text under the current model.
         Args:
@@ -98,7 +99,11 @@ def get_probabilities(self,
                 text, return_tensors="pt")
             labels = tokenized.input_ids
 
-        all_prob = []
+        target_token_log_prob = []
+        all_token_prob = {
+            'log_prob': [],
+            'prob': []
+        }
         for i in range(0, labels.size(1), self.stride):
             begin_loc = max(i + self.stride - self.max_length, 0)
             end_loc = min(i + self.stride, labels.size(1))
@@ -111,7 +116,8 @@
             if no_grads:
                 logits = logits.cpu()
             shift_logits = logits[..., :-1, :].contiguous()
-            probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
+            probabilities = torch.nn.functional.softmax(shift_logits, dim=-1)
+            log_probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
             shift_labels = target_ids[..., 1:]
             if no_grads:
                 shift_labels = shift_labels.cpu()
@@ -123,17 +129,26 @@
 
             for i, token_id in enumerate(labels_processed):
                 if token_id != -100:
-                    probability = probabilities[0, i, token_id]
+                    log_probability = log_probabilities[0, i, token_id]
                     if no_grads:
-                        probability = probability.item()
-                    all_prob.append(probability)
+                        log_probability = log_probability.item()
+                    target_token_log_prob.append(log_probability)
+                    all_token_prob['log_prob'].append(log_probabilities[0, i])
+                    all_token_prob["prob"].append(probabilities[0, i])
 
         # Should be equal to # of tokens - 1 to account for shift
-        assert len(all_prob) == labels.size(1) - 1
+        assert len(target_token_log_prob) == labels.size(1) - 1
+        all_token_prob['log_prob'] = torch.stack(all_token_prob['log_prob'], dim=0)
+        all_token_prob['prob'] = torch.stack(all_token_prob['prob'], dim=0)
+        assert len(target_token_log_prob) == len(all_token_prob['log_prob'])
+        assert len(target_token_log_prob) == len(all_token_prob['prob'])
 
         if not no_grads:
-            all_prob = torch.stack(all_prob)
+            target_token_log_prob = torch.stack(target_token_log_prob)
 
-        return all_prob
+        if not return_all_probs:
+            return target_token_log_prob
+        return target_token_log_prob, all_token_prob
 
     @torch.no_grad()
     def get_ll(self,
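From a caller's perspective, the `get_probabilities` change is backward compatible: the default return value is untouched, and `return_all_probs=True` additionally yields the full per-position distributions. A small sketch (the function name is hypothetical; the shapes follow the asserts above):

```python
def inspect_probabilities(model, text: str):
    """Hypothetical demo of the two return shapes of Model.get_probabilities."""
    # Backward-compatible default: per-token log-probs, length n_tokens - 1
    token_log_probs = model.get_probabilities(text)

    # New path: also return the full next-token distribution at each position
    token_log_probs, all_probs = model.get_probabilities(text, return_all_probs=True)
    # all_probs['log_prob'] and all_probs['prob'] are (n_tokens - 1, vocab_size)
    # tensors, row-aligned with token_log_probs
    assert all_probs['prob'].shape[0] == len(token_log_probs)
    return token_log_probs, all_probs
```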
run.py (49 changes: 35 additions & 14 deletions)
@@ -124,11 +124,11 @@ def get_mia_scores(
         neighbors_within = {n_perturbation: [] for n_perturbation in n_perturbation_list}
         for i, substr in enumerate(sample):
             # compute token probabilities for sample
-            s_tk_probs = (
-                target_model.get_probabilities(substr)
+            s_tk_probs, s_all_probs = (
+                target_model.get_probabilities(substr, return_all_probs=True)
                 if not config.pretokenized
                 else target_model.get_probabilities(
-                    detokenized_sample[i], tokens=substr
+                    detokenized_sample[i], tokens=substr, return_all_probs=True
                 )
             )
 
@@ -150,17 +150,38 @@
                     continue
 
                 if attack != AllAttacks.NEIGHBOR:
-                    score = attacker.attack(
-                        substr,
-                        probs=s_tk_probs,
-                        detokenized_sample=(
-                            detokenized_sample[i]
-                            if config.pretokenized
-                            else None
-                        ),
-                        loss=loss,
-                    )
-                    sample_information[attack].append(score)
+                    if attack in [AllAttacks.MIN_K, AllAttacks.MIN_K_PLUS_PLUS]:
+                        # For Min-K% and Min-K%++,
+                        # iterate over k
+                        for k in [
+                            0.1, 0.2, 0.3, 0.4, 0.5,
+                            0.6, 0.7, 0.8, 0.9, 1.0
+                        ]:
+                            score = attacker.attack(
+                                substr,
+                                probs=s_tk_probs,
+                                detokenized_sample=(
+                                    detokenized_sample[i]
+                                    if config.pretokenized
+                                    else None
+                                ),
+                                loss=loss,
+                                all_probs=s_all_probs,  # for Min-K%++
+                                k=k
+                            )
+                            sample_information[f"{attack}_{k}"].append(score)
+                    else:
+                        score = attacker.attack(
+                            substr,
+                            probs=s_tk_probs,
+                            detokenized_sample=(
+                                detokenized_sample[i]
+                                if config.pretokenized
+                                else None
+                            ),
+                            loss=loss,
+                        )
+                        sample_information[attack].append(score)
                 else:
                     # For each 'number of neighbors'
                     for n_perturbation in n_perturbation_list:
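Note the bookkeeping change that comes with the sweep: Min-K% and Min-K%++ now record one score per k rather than a single score per sample, keyed with an `{attack}_{k}` suffix. Assuming the enum members format to their string values, the resulting `sample_information` keys look like this (illustrative only):

```python
ks = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
print([f"min_k_{k}" for k in ks][:3])    # ['min_k_0.1', 'min_k_0.2', 'min_k_0.3']
print([f"min_k++_{k}" for k in ks][-1])  # min_k++_1.0
```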
