-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add ReCaLL attack #26
Changes from 11 commits
e7cbe89
dc6260f
84752ec
4510d96
4472fab
3cc6c9b
3f19c7e
0249ab9
793bdd7
475e66b
c701cc0
bdb3c06
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
{ | ||
"experiment_name": "recall", | ||
"base_model": "EleutherAI/pythia-1.4b", | ||
"dataset_member": "the_pile", | ||
"dataset_nonmember": "the_pile", | ||
"min_words": 100, | ||
"max_words": 200, | ||
"max_tokens": 512, | ||
"max_data": 100000, | ||
"output_name": "unified_mia", | ||
"specific_source": "Github_ngram_13_<0.8_truncated", | ||
"n_samples": 1000, | ||
"recall_num_shots": 1, | ||
"blackbox_attacks": ["loss", "ref", "zlib", "min_k", "min_k++", "recall"], | ||
"env_config": { | ||
"results": "results_new", | ||
"device": "cuda:0", | ||
"device_aux": "cuda:0" | ||
}, | ||
"ref_config": { | ||
"models": [ | ||
"EleutherAI/pythia-160m" | ||
] | ||
}, | ||
"neighborhood_config": { | ||
"model": "bert", | ||
"n_perturbation_list": [ | ||
25 | ||
], | ||
"pct_words_masked": 0.3, | ||
"span_length": 2, | ||
"dump_cache": false, | ||
"load_from_cache": true, | ||
"neighbor_strategy": "random" | ||
}, | ||
"dump_cache": false, | ||
"load_from_cache": false, | ||
"load_from_hf": true | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
""" | ||
ReCaLL Attack: https://github.com/ruoyuxie/recall/ | ||
""" | ||
import torch | ||
import numpy as np | ||
from mimir.attacks.all_attacks import Attack | ||
from mimir.models import Model | ||
from mimir.config import ExperimentConfig | ||
|
||
class ReCaLLAttack(Attack): | ||
|
||
def __init__(self, config: ExperimentConfig, target_model: Model): | ||
super().__init__(config, target_model, ref_model = None) | ||
self.prefix = None | ||
|
||
@torch.no_grad() | ||
def _attack(self, document, probs, tokens = None, **kwargs): | ||
recall_dict: dict = kwargs.get("recall_dict", None) | ||
|
||
nonmember_prefix = recall_dict.get("prefix") | ||
num_shots = recall_dict.get("num_shots") | ||
avg_length = recall_dict.get("avg_length") | ||
|
||
assert nonmember_prefix, "nonmember_prefix should not be None or empty" | ||
|
||
lls = self.target_model.get_ll(document, probs = probs, tokens = tokens) | ||
ll_nonmember = self.get_conditional_ll(nonmember_prefix = nonmember_prefix, text = document, | ||
num_shots = num_shots, avg_length = avg_length, | ||
tokens = tokens) | ||
recall = ll_nonmember / lls | ||
|
||
|
||
assert not np.isnan(recall) | ||
return recall | ||
|
||
def process_prefix(self, prefix, avg_length, total_shots): | ||
model = self.target_model | ||
tokenizer = self.target_model.tokenizer | ||
|
||
if self.prefix is not None: | ||
# We only need to process the prefix once, after that we can just return | ||
return self.prefix | ||
|
||
max_length = model.max_length | ||
token_counts = [len(tokenizer.encode(shot)) for shot in prefix] | ||
|
||
target_token_count = avg_length | ||
total_tokens = sum(token_counts) + target_token_count | ||
if total_tokens<=max_length: | ||
self.prefix = prefix | ||
return self.prefix | ||
# Determine the maximum number of shots that can fit within the max_length | ||
max_shots = 0 | ||
cumulative_tokens = target_token_count | ||
for count in token_counts: | ||
if cumulative_tokens + count <= max_length: | ||
max_shots += 1 | ||
cumulative_tokens += count | ||
else: | ||
break | ||
# Truncate the prefix to include only the maximum number of shots | ||
truncated_prefix = prefix[-max_shots:] | ||
print(f"""\nToo many shots used. Initial ReCaLL number of shots was {total_shots}. Maximum number of shots is {max_shots}. Defaulting to maximum number of shots.""") | ||
self.prefix = truncated_prefix | ||
return self.prefix | ||
|
||
def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None): | ||
assert nonmember_prefix, "nonmember_prefix should not be None or empty" | ||
|
||
model = self.target_model | ||
tokenizer = self.target_model.tokenizer | ||
|
||
if tokens is None: | ||
target_encodings = tokenizer(text=text, return_tensors="pt") | ||
else: | ||
target_encodings = tokens | ||
|
||
processed_prefix = self.process_prefix(nonmember_prefix, avg_length, total_shots=num_shots) | ||
input_encodings = tokenizer(text="".join(processed_prefix), return_tensors="pt") | ||
|
||
prefix_ids = input_encodings.input_ids.to(model.device) | ||
text_ids = target_encodings.input_ids.to(model.device) | ||
|
||
max_length = model.max_length | ||
|
||
if prefix_ids.size(1) >= max_length: | ||
raise ValueError("Prefix length exceeds or equals the model's maximum context window.") | ||
|
||
labels = torch.cat((prefix_ids, text_ids), dim=1) | ||
total_length = labels.size(1) | ||
|
||
total_loss = 0 | ||
total_tokens = 0 | ||
with torch.no_grad(): | ||
for i in range(0, total_length, max_length): | ||
begin_loc = i | ||
end_loc = min(i + max_length, total_length) | ||
trg_len = end_loc - begin_loc | ||
|
||
input_ids = labels[:, begin_loc:end_loc].to(model.device) | ||
target_ids = input_ids.clone() | ||
|
||
if begin_loc < prefix_ids.size(1): | ||
prefix_overlap = min(prefix_ids.size(1) - begin_loc, max_length) | ||
target_ids[:, :prefix_overlap] = -100 | ||
|
||
if end_loc > total_length - text_ids.size(1): | ||
target_overlap = min(end_loc - (total_length - text_ids.size(1)), max_length) | ||
target_ids[:, -target_overlap:] = input_ids[:, -target_overlap:] | ||
|
||
if torch.all(target_ids == -100): | ||
continue | ||
|
||
outputs = model.model(input_ids, labels=target_ids) | ||
loss = outputs.loss | ||
if torch.isnan(loss): | ||
print(f"NaN detected in loss at iteration {i}. Non masked target_ids size is {(target_ids != -100).sum().item()}") | ||
continue | ||
non_masked_tokens = (target_ids != -100).sum().item() | ||
total_loss += loss.item() * non_masked_tokens | ||
total_tokens += non_masked_tokens | ||
|
||
average_loss = total_loss / total_tokens if total_tokens > 0 else 0 | ||
return -average_loss | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -77,6 +77,7 @@ def get_mia_scores( | |
is_train: bool, | ||
n_samples: int = None, | ||
batch_size: int = 50, | ||
**kwargs | ||
): | ||
# Fix randomness | ||
fix_seed(config.random_seed) | ||
|
@@ -100,6 +101,14 @@ def get_mia_scores( | |
n_perturbation: [] for n_perturbation in n_perturbation_list | ||
} | ||
|
||
nonmember_prefix = kwargs.get("nonmember_prefix", None) | ||
if AllAttacks.RECALL in attackers_dict.keys(): | ||
if nonmember_prefix is None: | ||
raise ValueError("Must include a prefix for ReCaLL attack") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want this condition? |
||
num_shots = config.recall_num_shots | ||
avg_length = int(np.mean([len(target_model.tokenizer.encode(ex)) for ex in data["records"]])) | ||
recall_dict = {"prefix":nonmember_prefix, "num_shots":num_shots, "avg_length":avg_length} | ||
|
||
# For each batch of data | ||
# TODO: Batch-size isn't really "batching" data - change later | ||
for batch in tqdm(range(math.ceil(n_samples / batch_size)), desc=f"Computing criterion"): | ||
|
@@ -160,8 +169,10 @@ def get_mia_scores( | |
), | ||
loss=loss, | ||
all_probs=s_all_probs, | ||
recall_dict = recall_dict | ||
) | ||
sample_information[attack].append(score) | ||
|
||
else: | ||
# For each 'number of neighbors' | ||
for n_perturbation in n_perturbation_list: | ||
|
@@ -515,6 +526,21 @@ def main(config: ExperimentConfig): | |
mask_model_tokenizer=mask_model.tokenizer if mask_model else None, | ||
) | ||
|
||
#* ReCaLL Specific | ||
if AllAttacks.RECALL in config.blackbox_attacks: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the config has multiple attacks, |
||
num_shots = config.recall_num_shots | ||
nonmember_prefix = data_nonmember[:num_shots] | ||
nonmember_data = data_nonmember[num_shots:] | ||
|
||
member_prefix = data_member[:num_shots] | ||
member_data = data_member[num_shots:] | ||
|
||
data_nonmember = nonmember_data | ||
data_member = member_data | ||
else: | ||
nonmember_prefix = None | ||
|
||
|
||
other_objs, other_nonmembers = None, None | ||
if config.dataset_nonmember_other_sources is not None: | ||
other_objs, other_nonmembers = [], [] | ||
|
@@ -628,7 +654,8 @@ def main(config: ExperimentConfig): | |
ref_models=ref_models, | ||
config=config, | ||
is_train=True, | ||
n_samples=n_samples | ||
n_samples=n_samples, | ||
nonmember_prefix = nonmember_prefix | ||
) | ||
# Collect scores for non-members | ||
nonmember_preds, nonmember_samples = get_mia_scores( | ||
|
@@ -640,6 +667,7 @@ def main(config: ExperimentConfig): | |
config=config, | ||
is_train=False, | ||
n_samples=n_samples, | ||
nonmember_prefix = nonmember_prefix | ||
) | ||
blackbox_outputs = compute_metrics_from_scores( | ||
member_preds, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you create a separate class for Configuration (just like we have a separate
NeighborhoodConfig
for neighborhood attack) for this instead of adding it directly to the ExperimentConfig?