From 014d0ea96a922a807846d0fb35032889e294af66 Mon Sep 17 00:00:00 2001
From: Anshuman Suri
Date: Tue, 6 Feb 2024 09:39:31 -0500
Subject: [PATCH] Update attack flow/logic in run.py

---
 .gitignore                        |    6 +-
 cross.py                          | 1031 -----------------------------
 mimir/attacks/attack_utils.py     |   14 +-
 mimir/attacks/blackbox_attacks.py |   13 +-
 mimir/attacks/neighborhood.py     |    2 +-
 mimir/attacks/reference.py        |    6 +-
 mimir/attacks/utils.py            |    1 -
 mimir/models.py                   |   48 +-
 run.py                            |  597 ++++++++---------
 9 files changed, 332 insertions(+), 1386 deletions(-)
 delete mode 100644 cross.py

diff --git a/.gitignore b/.gitignore
index 188dc18..86d3d59 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,6 +22,7 @@ data/*/
 
 # Logs (from cluster runs)
 logs/*
+logs_cluster/*
 
 # Vscode
 .vscode/*
@@ -54,4 +55,7 @@ quantile_ref_model/*
 # Temp
 table*/*
 fig7/*
-new_mi/*
\ No newline at end of file
+new_mi/*
+
+# Ignore all tar.gz files
+*.tar.gz
\ No newline at end of file
diff --git a/cross.py b/cross.py
deleted file mode 100644
index 105a6a4..0000000
--- a/cross.py
+++ /dev/null
@@ -1,1031 +0,0 @@
-import matplotlib.pyplot as plt
-import numpy as np
-import datasets
-import transformers
-import re
-import torch
-import torch.nn.functional as F
-import tqdm
-import random
-from sklearn.metrics import roc_curve, precision_recall_curve, auc
-import argparse
-import datetime
-import os
-import json
-import functools
-import mimir.custom_datasets as custom_datasets
-from multiprocessing.pool import ThreadPool
-import time
-
-# 15 colorblind-friendly colors
-COLORS = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442",
-          "#56B4E9", "#E69F00", "#000000", "#0072B2", "#009E73",
-          "#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]
-
-# define regex to match all <extra_id_*> tokens, where * is an integer
-pattern = re.compile(r"<extra_id_\d+>")
-
-
-def load_base_model():
-    print('MOVING BASE MODEL TO GPU...', end='', flush=True)
-    start = time.time()
-    try:
-        mask_model.cpu()
-    except NameError:
-        pass
-    if args.openai_model is None:
-        base_model.to(DEVICE)
-    print(f'DONE ({time.time() - start:.2f}s)')
-
-
-def load_mask_model():
-    print('MOVING MASK MODEL TO GPU...', end='', flush=True)
-    start = time.time()
-
-    if args.openai_model is None:
-        base_model.cpu()
-    if not args.random_fills:
-        mask_model.to(DEVICE)
-    print(f'DONE ({time.time() - start:.2f}s)')
-
-
-def tokenize_and_mask(text, span_length, pct, ceil_pct=False):
-    tokens = text.split(' ')
-    mask_string = '<<<mask>>>'
-
-    n_spans = pct * len(tokens) / (span_length + args.buffer_size * 2)
-    if ceil_pct:
-        n_spans = np.ceil(n_spans)
-    n_spans = int(n_spans)
-
-    n_masks = 0
-    while n_masks < n_spans:
-        start = np.random.randint(0, len(tokens) - span_length)
-        end = start + span_length
-        search_start = max(0, start - args.buffer_size)
-        search_end = min(len(tokens), end + args.buffer_size)
-        if mask_string not in tokens[search_start:search_end]:
-            tokens[start:end] = [mask_string]
-            n_masks += 1
-
-    # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
-    num_filled = 0
-    for idx, token in enumerate(tokens):
-        if token == mask_string:
-            tokens[idx] = f'<extra_id_{num_filled}>'
-            num_filled += 1
-    assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
-    text = ' '.join(tokens)
-    return text
-
-
-def count_masks(texts):
-    return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts]
-
-
-# replace each masked span with a sample from T5 mask_model
-def replace_masks(texts):
-    n_expected = count_masks(texts)
-    stop_id = mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
-    tokens = mask_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
-    outputs = mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=args.mask_top_p, num_return_sequences=1, eos_token_id=stop_id)
-
return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False) - - -def extract_fills(texts): - # remove from beginning of each text - texts = [x.replace("", "").replace("", "").strip() for x in texts] - - # return the text in between each matched mask token - extracted_fills = [pattern.split(x)[1:-1] for x in texts] - - # remove whitespace around each fill - extracted_fills = [[y.strip() for y in x] for x in extracted_fills] - - return extracted_fills - - -def apply_extracted_fills(masked_texts, extracted_fills): - # split masked text into tokens, only splitting on spaces (not newlines) - tokens = [x.split(' ') for x in masked_texts] - - n_expected = count_masks(masked_texts) - - # replace each mask token with the corresponding fill - for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)): - if len(fills) < n: - tokens[idx] = [] - else: - for fill_idx in range(n): - text[text.index(f"")] = fills[fill_idx] - - # join tokens back into text - texts = [" ".join(x) for x in tokens] - return texts - - -def perturb_texts_(texts, span_length, pct, ceil_pct=False): - if not args.random_fills: - masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts] - raw_fills = replace_masks(masked_texts) - extracted_fills = extract_fills(raw_fills) - perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills) - - # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again - attempts = 1 - while '' in perturbed_texts: - idxs = [idx for idx, x in enumerate(perturbed_texts) if x == ''] - print(f'WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].') - masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for idx, x in enumerate(texts) if idx in idxs] - raw_fills = replace_masks(masked_texts) - extracted_fills = extract_fills(raw_fills) - new_perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills) - for idx, x in zip(idxs, new_perturbed_texts): - perturbed_texts[idx] = x - attempts += 1 - else: - if args.random_fills_tokens: - # tokenize base_tokenizer - tokens = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE) - valid_tokens = tokens.input_ids != base_tokenizer.pad_token_id - replace_pct = args.pct_words_masked * (args.span_length / (args.span_length + 2 * args.buffer_size)) - - # replace replace_pct of input_ids with random tokens - random_mask = torch.rand(tokens.input_ids.shape, device=DEVICE) < replace_pct - random_mask &= valid_tokens - random_tokens = torch.randint(0, base_tokenizer.vocab_size, (random_mask.sum(),), device=DEVICE) - # while any of the random tokens are special tokens, replace them with random non-special tokens - while any(base_tokenizer.decode(x) in base_tokenizer.all_special_tokens for x in random_tokens): - random_tokens = torch.randint(0, base_tokenizer.vocab_size, (random_mask.sum(),), device=DEVICE) - tokens.input_ids[random_mask] = random_tokens - perturbed_texts = base_tokenizer.batch_decode(tokens.input_ids, skip_special_tokens=True) - else: - masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts] - perturbed_texts = masked_texts - # replace each with args.span_length random words from FILL_DICTIONARY - for idx, text in enumerate(perturbed_texts): - filled_text = text - for fill_idx in range(count_masks([text])[0]): - fill = random.sample(FILL_DICTIONARY, span_length) - filled_text = filled_text.replace(f"", " ".join(fill)) - assert count_masks([filled_text])[0] == 0, 
"Failed to replace all masks" - perturbed_texts[idx] = filled_text - - return perturbed_texts - - -def perturb_texts(texts, span_length, pct, ceil_pct=False): - chunk_size = args.chunk_size - if '11b' in mask_filling_model_name: - chunk_size //= 2 - - outputs = [] - for i in tqdm.tqdm(range(0, len(texts), chunk_size), desc="Applying perturbations"): - outputs.extend(perturb_texts_(texts[i:i + chunk_size], span_length, pct, ceil_pct=ceil_pct)) - return outputs - - -def drop_last_word(text): - return ' '.join(text.split(' ')[:-1]) - - -def _openai_sample(p): - if args.dataset != 'pubmed': # keep Answer: prefix for pubmed - p = drop_last_word(p) - - # sample from the openai model - kwargs = { "engine": args.openai_model, "max_tokens": 200 } - if args.do_top_p: - kwargs['top_p'] = args.top_p - - r = openai.Completion.create(prompt=f"{p}", **kwargs) - return p + r['choices'][0].text - - -# sample from base_model using ****only**** the first 30 tokens in each example as context -def sample_from_model(texts, min_words=55, prompt_tokens=30): - # encode each text as a list of token ids - if args.dataset == 'pubmed': - texts = [t[:t.index(custom_datasets.SEPARATOR)] for t in texts] - all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE) - else: - all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE) - all_encoded = {key: value[:, :prompt_tokens] for key, value in all_encoded.items()} - - if args.openai_model: - # decode the prefixes back into text - prefixes = base_tokenizer.batch_decode(all_encoded['input_ids'], skip_special_tokens=True) - pool = ThreadPool(args.batch_size) - - decoded = pool.map(_openai_sample, prefixes) - else: - decoded = ['' for _ in range(len(texts))] - - # sample from the model until we get a sample with at least min_words words for each example - # this is an inefficient way to do this (since we regenerate for all inputs if just one is too short), but it works - tries = 0 - while (m := min(len(x.split()) for x in decoded)) < min_words: - if tries != 0: - print() - print(f"min words: {m}, needed {min_words}, regenerating (try {tries})") - - sampling_kwargs = {} - if args.do_top_p: - sampling_kwargs['top_p'] = args.top_p - elif args.do_top_k: - sampling_kwargs['top_k'] = args.top_k - min_length = 50 if args.dataset in ['pubmed'] else 150 - tries += 1 - - if args.openai_model: - global API_TOKEN_COUNTER - - # count total number of tokens with GPT2_TOKENIZER - total_tokens = sum(len(GPT2_TOKENIZER.encode(x)) for x in decoded) - API_TOKEN_COUNTER += total_tokens - - return decoded - - -def get_likelihood(logits, labels): - assert logits.shape[0] == 1 - assert labels.shape[0] == 1 - - logits = logits.view(-1, logits.shape[-1])[:-1] - labels = labels.view(-1)[1:] - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - log_likelihood = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - return log_likelihood.mean() - - -# Get the log likelihood of each text under the base_model -def get_ll(text): - if args.openai_model: - kwargs = { "engine": args.openai_model, "temperature": 0, "max_tokens": 0, "echo": True, "logprobs": 0} - r = openai.Completion.create(prompt=f"<|endoftext|>{text}", **kwargs) - result = r['choices'][0] - tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:] - - assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}" - - return np.mean(logprobs) - else: - with torch.no_grad(): - tokenized = base_tokenizer(text, 
return_tensors="pt").to(DEVICE) - labels = tokenized.input_ids - return -base_model(**tokenized, labels=labels).loss.item() - - - -# Get the likelihood ratio of each text under the base_model -- MIA baseline -def get_lira(text): - if args.openai_model: - print("NOT IMPLEMENTED") - exit(0) - kwargs = { "engine": args.openai_model, "temperature": 0, "max_tokens": 0, "echo": True, "logprobs": 0} - r = openai.Completion.create(prompt=f"<|endoftext|>{text}", **kwargs) - result = r['choices'][0] - tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:] - - assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}" - - return np.mean(logprobs) - else: - with torch.no_grad(): - tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE) - labels = tokenized.input_ids - tokenized_ref = ref_tokenizer(text, return_tensors="pt").to(DEVICE) - lls = -base_model(**tokenized, labels=labels).loss.item() - lls_ref = -ref_model(**tokenized_ref, labels=labels).loss.item() - - return lls - lls_ref - - - -def get_lls(texts): - if not args.openai_model: - return [get_ll(text) for text in texts] - else: - global API_TOKEN_COUNTER - - # use GPT2_TOKENIZER to get total number of tokens - total_tokens = sum(len(GPT2_TOKENIZER.encode(text)) for text in texts) - API_TOKEN_COUNTER += total_tokens * 2 # multiply by two because OpenAI double-counts echo_prompt tokens - - pool = ThreadPool(args.batch_size) - return pool.map(get_ll, texts) - - -# get the average rank of each observed token sorted by model likelihood -def get_rank(text, log=False): - assert args.openai_model is None, "get_rank not implemented for OpenAI models" - - with torch.no_grad(): - tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE) - logits = base_model(**tokenized).logits[:,:-1] - labels = tokenized.input_ids[:,1:] - - # get rank of each label token in the model's likelihood ordering - matches = (logits.argsort(-1, descending=True) == labels.unsqueeze(-1)).nonzero() - - assert matches.shape[1] == 3, f"Expected 3 dimensions in matches tensor, got {matches.shape}" - - ranks, timesteps = matches[:,-1], matches[:,-2] - - # make sure we got exactly one match for each timestep in the sequence - assert (timesteps == torch.arange(len(timesteps)).to(timesteps.device)).all(), "Expected one match per timestep" - - ranks = ranks.float() + 1 # convert to 1-indexed rank - if log: - ranks = torch.log(ranks) - - return ranks.float().mean().item() - - -# get average entropy of each token in the text -def get_entropy(text): - assert args.openai_model is None, "get_entropy not implemented for OpenAI models" - - with torch.no_grad(): - tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE) - logits = base_model(**tokenized).logits[:,:-1] - neg_entropy = F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1) - return -neg_entropy.sum(-1).mean().item() - - -def get_roc_metrics(real_preds, sample_preds): - fpr, tpr, _ = roc_curve([0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds) - roc_auc = auc(fpr, tpr) - return fpr.tolist(), tpr.tolist(), float(roc_auc) - - -def get_precision_recall_metrics(real_preds, sample_preds): - precision, recall, _ = precision_recall_curve([0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds) - pr_auc = auc(recall, precision) - return precision.tolist(), recall.tolist(), float(pr_auc) - - -# save the ROC curve for each experiment, given a list of output dictionaries, one for each experiment, 
using colorblind-friendly colors -def save_roc_curves(experiments): - # first, clear plt - plt.clf() - - for experiment, color in zip(experiments, COLORS): - metrics = experiment["metrics"] - plt.plot(metrics["fpr"], metrics["tpr"], label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}", color=color) - # print roc_auc for this experiment - print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}") - plt.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--') - plt.xlim([0.0, 1.0]) - plt.ylim([0.0, 1.05]) - plt.xlabel('False Positive Rate') - plt.ylabel('True Positive Rate') - plt.title(f'ROC Curves ({base_model_name} - {args.mask_filling_model_name})') - plt.legend(loc="lower right", fontsize=6) - plt.savefig(f"{SAVE_FOLDER}/roc_curves.png") - - -# save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed -def save_ll_histograms(experiments): - # first, clear plt - plt.clf() - - for experiment in experiments: - try: - results = experiment["raw_results"] - # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right - plt.figure(figsize=(20, 6)) - plt.subplot(1, 2, 1) - plt.hist([r["sampled_ll"] for r in results], alpha=0.5, bins='auto', label='sampled') - plt.hist([r["perturbed_sampled_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed sampled') - plt.xlabel("log likelihood") - plt.ylabel('count') - plt.legend(loc='upper right') - plt.subplot(1, 2, 2) - plt.hist([r["original_ll"] for r in results], alpha=0.5, bins='auto', label='original') - plt.hist([r["perturbed_original_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed original') - plt.xlabel("log likelihood") - plt.ylabel('count') - plt.legend(loc='upper right') - plt.savefig(f"{SAVE_FOLDER}/ll_histograms_{experiment['name']}.png") - except: - pass - - -# save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed -def save_llr_histograms(experiments): - # first, clear plt - plt.clf() - - for experiment in experiments: - try: - results = experiment["raw_results"] - # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right - plt.figure(figsize=(20, 6)) - plt.subplot(1, 2, 1) - - # compute the log likelihood ratio for each result - for r in results: - r["sampled_llr"] = r["sampled_ll"] - r["perturbed_sampled_ll"] - r["original_llr"] = r["original_ll"] - r["perturbed_original_ll"] - - plt.hist([r["sampled_llr"] for r in results], alpha=0.5, bins='auto', label='sampled') - plt.hist([r["original_llr"] for r in results], alpha=0.5, bins='auto', label='original') - plt.xlabel("log likelihood ratio") - plt.ylabel('count') - plt.legend(loc='upper right') - plt.savefig(f"{SAVE_FOLDER}/llr_histograms_{experiment['name']}.png") - except: - pass - - -def get_perturbation_results(span_length=10, n_perturbations=1, n_samples=500): - load_mask_model() - - torch.manual_seed(0) - np.random.seed(0) - - results = [] - original_text = data["original"] - sampled_text = data["sampled"] - - perturb_fn = functools.partial(perturb_texts, span_length=span_length, pct=args.pct_words_masked) - - p_sampled_text = perturb_fn([x for x in sampled_text for _ in range(n_perturbations)]) - p_original_text = perturb_fn([x for x in original_text for _ in range(n_perturbations)]) - for _ in range(n_perturbation_rounds - 1): - try: - p_sampled_text, p_original_text = perturb_fn(p_sampled_text), 
perturb_fn(p_original_text) - except AssertionError: - break - - assert len(p_sampled_text) == len(sampled_text) * n_perturbations, f"Expected {len(sampled_text) * n_perturbations} perturbed samples, got {len(p_sampled_text)}" - assert len(p_original_text) == len(original_text) * n_perturbations, f"Expected {len(original_text) * n_perturbations} perturbed samples, got {len(p_original_text)}" - - for idx in range(len(original_text)): - results.append({ - "original": original_text[idx], - "sampled": sampled_text[idx], - "perturbed_sampled": p_sampled_text[idx * n_perturbations: (idx + 1) * n_perturbations], - "perturbed_original": p_original_text[idx * n_perturbations: (idx + 1) * n_perturbations] - }) - - load_base_model() - - for res in tqdm.tqdm(results, desc="Computing log likelihoods"): - p_sampled_ll = get_lls(res["perturbed_sampled"]) - p_original_ll = get_lls(res["perturbed_original"]) - res["original_ll"] = get_ll(res["original"]) - res["sampled_ll"] = get_ll(res["sampled"]) - res["all_perturbed_sampled_ll"] = p_sampled_ll - res["all_perturbed_original_ll"] = p_original_ll - res["perturbed_sampled_ll"] = np.mean(p_sampled_ll) - res["perturbed_original_ll"] = np.mean(p_original_ll) - res["perturbed_sampled_ll_std"] = np.std(p_sampled_ll) if len(p_sampled_ll) > 1 else 1 - res["perturbed_original_ll_std"] = np.std(p_original_ll) if len(p_original_ll) > 1 else 1 - - return results - - -def run_perturbation_experiment(results, criterion, span_length=10, n_perturbations=1, n_samples=500): - # compute diffs with perturbed - predictions = {'real': [], 'samples': []} - for res in results: - if criterion == 'd': - predictions['real'].append(res['original_ll'] - res['perturbed_original_ll']) - predictions['samples'].append(res['sampled_ll'] - res['perturbed_sampled_ll']) - elif criterion == 'z': - if res['perturbed_original_ll_std'] == 0: - res['perturbed_original_ll_std'] = 1 - print("WARNING: std of perturbed original is 0, setting to 1") - print(f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}") - print(f"Original text: {res['original']}") - if res['perturbed_sampled_ll_std'] == 0: - res['perturbed_sampled_ll_std'] = 1 - print("WARNING: std of perturbed sampled is 0, setting to 1") - print(f"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}") - print(f"Sampled text: {res['sampled']}") - predictions['real'].append((res['original_ll'] - res['perturbed_original_ll']) / res['perturbed_original_ll_std']) - predictions['samples'].append((res['sampled_ll'] - res['perturbed_sampled_ll']) / res['perturbed_sampled_ll_std']) - - fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples']) - p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples']) - name = f'perturbation_{n_perturbations}_{criterion}' - print(f"{name} ROC AUC: {roc_auc}, PR AUC: {pr_auc}") - return { - 'name': name, - 'predictions': predictions, - 'info': { - 'pct_words_masked': args.pct_words_masked, - 'span_length': span_length, - 'n_perturbations': n_perturbations, - 'n_samples': n_samples, - }, - 'raw_results': results, - 'metrics': { - 'roc_auc': roc_auc, - 'fpr': fpr, - 'tpr': tpr, - }, - 'pr_metrics': { - 'pr_auc': pr_auc, - 'precision': p, - 'recall': r, - }, - 'loss': 1 - pr_auc, - } - - -def run_baseline_threshold_experiment(criterion_fn, name, n_samples=500): - torch.manual_seed(0) - np.random.seed(0) - - results = [] - for batch in tqdm.tqdm(range(n_samples // batch_size), desc=f"Computing {name} criterion"): - 
original_text = data["original"][batch * batch_size:(batch + 1) * batch_size] - sampled_text = data["sampled"][batch * batch_size:(batch + 1) * batch_size] - - for idx in range(len(original_text)): - results.append({ - "original": original_text[idx], - "original_crit": criterion_fn(original_text[idx]), - "sampled": sampled_text[idx], - "sampled_crit": criterion_fn(sampled_text[idx]), - }) - - # compute prediction scores for real/sampled passages - predictions = { - 'real': [x["original_crit"] for x in results], - 'samples': [x["sampled_crit"] for x in results], - } - - fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples']) - p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples']) - print(f"{name}_threshold ROC AUC: {roc_auc}, PR AUC: {pr_auc}") - return { - 'name': f'{name}_threshold', - 'predictions': predictions, - 'info': { - 'n_samples': n_samples, - }, - 'raw_results': results, - 'metrics': { - 'roc_auc': roc_auc, - 'fpr': fpr, - 'tpr': tpr, - }, - 'pr_metrics': { - 'pr_auc': pr_auc, - 'precision': p, - 'recall': r, - }, - 'loss': 1 - pr_auc, - } - - -# strip newlines from each example; replace one or more newlines with a single space -def strip_newlines(text): - return ' '.join(text.split()) - - -# trim to shorter length -def trim_to_shorter_length(texta, textb): - # truncate to shorter of o and s - shorter_length = min(len(texta.split(' ')), len(textb.split(' '))) - texta = ' '.join(texta.split(' ')[:shorter_length]) - textb = ' '.join(textb.split(' ')[:shorter_length]) - return texta, textb - - -def truncate_to_substring(text, substring, idx_occurrence): - # truncate everything after the idx_occurrence occurrence of substring - assert idx_occurrence > 0, 'idx_occurrence must be > 0' - idx = -1 - for _ in range(idx_occurrence): - idx = text.find(substring, idx + 1) - if idx == -1: - return text - return text[:idx] - - -def generate_samples(raw_data, batch_size): - torch.manual_seed(42) - np.random.seed(42) - data = { - "original": [], - "sampled": [], - } - - for batch in range(len(raw_data) // batch_size): - print('Generating samples for batch', batch, 'of', len(raw_data) // batch_size) - original_text = raw_data[batch * batch_size:(batch + 1) * batch_size] - sampled_text = sample_from_model(original_text, min_words=30 if args.dataset in ['pubmed'] else 55) - - for o, s in zip(original_text, sampled_text): - if args.dataset == 'pubmed': - s = truncate_to_substring(s, 'Question:', 2) - o = o.replace(custom_datasets.SEPARATOR, ' ') - - o, s = trim_to_shorter_length(o, s) - - # add to the data - data["original"].append(o) - data["sampled"].append(s) - - if args.pre_perturb_pct > 0: - print(f'APPLYING {args.pre_perturb_pct}, {args.pre_perturb_span_length} PRE-PERTURBATIONS') - load_mask_model() - data["sampled"] = perturb_texts(data["sampled"], args.pre_perturb_span_length, args.pre_perturb_pct, ceil_pct=True) - load_base_model() - - return data - - -def generate_data(dataset, key): - # load data - if dataset in custom_datasets.DATASETS: - data = custom_datasets.load(dataset, cache_dir) - else: - data = datasets.load_dataset(dataset, split='train', cache_dir=cache_dir)[key] - - # get unique examples, strip whitespace, and remove newlines - # then take just the long examples, shuffle, take the first 5,000 to tokenize to save time - # then take just the examples that are <= 512 tokens (for the mask model) - # then generate n_samples samples - - # remove duplicates from the data - data = list(dict.fromkeys(data)) # deterministic, 
as opposed to set() - - # strip whitespace around each example - data = [x.strip() for x in data] - - # remove newlines from each example - data = [strip_newlines(x) for x in data] - - # try to keep only examples with > 250 words - if dataset in ['writing', 'squad', 'xsum']: - long_data = [x for x in data if len(x.split()) > 250] - if len(long_data) > 0: - data = long_data - - random.seed(0) - random.shuffle(data) - - data = data[:5_000] - - # keep only examples with <= 512 tokens according to mask_tokenizer - # this step has the extra effect of removing examples with low-quality/garbage content - tokenized_data = preproc_tokenizer(data) - data = [x for x, y in zip(data, tokenized_data["input_ids"]) if len(y) <= 512] - - # print stats about remainining data - print(f"Total number of samples: {len(data)}") - print(f"Average number of words: {np.mean([len(x.split()) for x in data])}") - - return generate_samples(data[:n_samples], batch_size=batch_size) - - -def load_data_gens(gen_a_file, gen_b_file, original_b=False,original_a=False): - - data = {'original':[],'sampled':[]} - - if original_b and not original_a: - for line_b in open(gen_b_file,'r'): - di_b= json.loads(line_b) - - #we only want the sampled stuff - gen_a is original and gen_b is sampled - - for element in di_b['original']: - data['original'].append(element) - - for element in di_b['sampled']: - data['sampled'].append(element) - - elif original_b and original_a: - - for line_a , line_b in zip(open(gen_a_file,'r'),open(gen_b_file,'r')): - di_a= json.loads(line_a) - di_b= json.loads(line_b) - - #we only want the sampled stuff - gen_a is original and gen_b is sampled - - for element in di_a['original']: - data['original'].append(element) - - for element in di_b['original']: - data['sampled'].append(element) - - elif original_a and not original_b: - - - for line_a , line_b in zip(open(gen_a_file,'r'),open(gen_b_file,'r')): - di_a= json.loads(line_a) - di_b= json.loads(line_b) - - #we only want the sampled stuff - gen_a is original and gen_b is sampled - - for element in di_a['original']: - data['original'].append(element) - - for element in di_b['sampled']: - data['sampled'].append(element) - - - else: - - for line_a , line_b in zip(open(gen_a_file,'r'),open(gen_b_file,'r')): - di_a= json.loads(line_a) - di_b= json.loads(line_b) - - #we only want the sampled stuff - gen_a is original and gen_b is sampled - - for element in di_a['sampled']: - data['original'].append(element) - - for element in di_b['sampled']: - data['sampled'].append(element) - - return data - -def load_base_model_and_tokenizer(name): - if args.openai_model is None: - print(f'Loading BASE model {args.base_model_name}...') - base_model_kwargs = {'revision':args.revision} - if 'gpt-j' in name or 'neox' in name: - base_model_kwargs.update(dict(torch_dtype=torch.float16)) - if 'gpt-j' in name: - base_model_kwargs.update(dict(revision='float16')) - base_model = transformers.AutoModelForCausalLM.from_pretrained(name, **base_model_kwargs, cache_dir=cache_dir) - else: - base_model = None - - optional_tok_kwargs = {} - if "facebook/opt-" in name: - print("Using non-fast tokenizer for OPT") - optional_tok_kwargs['fast'] = False - if args.dataset in ['pubmed']: - optional_tok_kwargs['padding_side'] = 'left' - base_tokenizer = transformers.AutoTokenizer.from_pretrained(name, **optional_tok_kwargs, cache_dir=cache_dir) - base_tokenizer.pad_token_id = base_tokenizer.eos_token_id - - return base_model, base_tokenizer - - -def eval_supervised(data, model): - print(f'Beginning 
supervised evaluation with {model}...') - detector = transformers.AutoModelForSequenceClassification.from_pretrained(model, cache_dir=cache_dir).to(DEVICE) - tokenizer = transformers.AutoTokenizer.from_pretrained(model, cache_dir=cache_dir) - - real, fake = data['original'], data['sampled'] - - with torch.no_grad(): - # get predictions for real - real_preds = [] - for batch in tqdm.tqdm(range(len(real) // batch_size), desc="Evaluating real"): - batch_real = real[batch * batch_size:(batch + 1) * batch_size] - batch_real = tokenizer(batch_real, padding=True, truncation=True, max_length=512, return_tensors="pt").to(DEVICE) - real_preds.extend(detector(**batch_real).logits.softmax(-1)[:,0].tolist()) - - # get predictions for fake - fake_preds = [] - for batch in tqdm.tqdm(range(len(fake) // batch_size), desc="Evaluating fake"): - batch_fake = fake[batch * batch_size:(batch + 1) * batch_size] - batch_fake = tokenizer(batch_fake, padding=True, truncation=True, max_length=512, return_tensors="pt").to(DEVICE) - fake_preds.extend(detector(**batch_fake).logits.softmax(-1)[:,0].tolist()) - - predictions = { - 'real': real_preds, - 'samples': fake_preds, - } - - fpr, tpr, roc_auc = get_roc_metrics(real_preds, fake_preds) - p, r, pr_auc = get_precision_recall_metrics(real_preds, fake_preds) - print(f"{model} ROC AUC: {roc_auc}, PR AUC: {pr_auc}") - - # free GPU memory - del detector - torch.cuda.empty_cache() - - return { - 'name': model, - 'predictions': predictions, - 'info': { - 'n_samples': n_samples, - }, - 'metrics': { - 'roc_auc': roc_auc, - 'fpr': fpr, - 'tpr': tpr, - }, - 'pr_metrics': { - 'pr_auc': pr_auc, - 'precision': p, - 'recall': r, - }, - 'loss': 1 - pr_auc, - } - - -if __name__ == '__main__': - DEVICE = "cuda" - - parser = argparse.ArgumentParser() - parser.add_argument('--dataset', type=str, default="xsum") - parser.add_argument('--dataset_key', type=str, default="document") - parser.add_argument('--pct_words_masked', type=float, default=0.3) # pct masked is actually pct_words_masked * (span_length / (span_length + 2 * buffer_size)) - parser.add_argument('--span_length', type=int, default=2) - parser.add_argument('--n_samples', type=int, default=200) - parser.add_argument('--n_perturbation_list', type=str, default="1,10") - parser.add_argument('--n_perturbation_rounds', type=int, default=1) - parser.add_argument('--base_model_name', type=str, default="gpt2-medium") - parser.add_argument('--revision', type=str, default="main") - parser.add_argument('--scoring_model_name', type=str, default="") - parser.add_argument('--mask_filling_model_name', type=str, default="t5-large") - parser.add_argument('--batch_size', type=int, default=50) - parser.add_argument('--chunk_size', type=int, default=20) - parser.add_argument('--n_similarity_samples', type=int, default=20) - parser.add_argument('--int8', action='store_true') - parser.add_argument('--half', action='store_true') - parser.add_argument('--base_half', action='store_true') - parser.add_argument('--do_top_k', action='store_true') - parser.add_argument('--top_k', type=int, default=40) - parser.add_argument('--do_top_p', action='store_true') - parser.add_argument('--top_p', type=float, default=0.96) - parser.add_argument('--output_name', type=str, default="") - parser.add_argument('--openai_model', type=str, default=None) - parser.add_argument('--openai_key', type=str) - parser.add_argument('--baselines_only', action='store_true') - parser.add_argument('--skip_baselines', action='store_true') - parser.add_argument('--buffer_size', type=int, 
default=1) - parser.add_argument('--mask_top_p', type=float, default=1.0) - parser.add_argument('--pre_perturb_pct', type=float, default=0.0) - parser.add_argument('--pre_perturb_span_length', type=int, default=5) - parser.add_argument('--random_fills', action='store_true') - parser.add_argument('--random_fills_tokens', action='store_true') - parser.add_argument('--cache_dir', type=str, default="/trunk/model-hub") - - #cross stuff - parser.add_argument('--root_gen', type=str, default="results/main/") - parser.add_argument('--gen_a', type=str, default="") - parser.add_argument('--gen_b', type=str, default="") - parser.add_argument('--original_b', type=bool, default=False) - parser.add_argument('--original_a', type=bool, default=False) - - # lira stuff - parser.add_argument('--ref_model', type=str, default=None) - - args = parser.parse_args() - - API_TOKEN_COUNTER = 0 - - if args.openai_model is not None: - import openai - assert args.openai_key is not None, "Must provide OpenAI API key as --openai_key" - openai.api_key = args.openai_key - - START_DATE = datetime.datetime.now().strftime('%Y-%m-%d') - START_TIME = datetime.datetime.now().strftime('%H-%M-%S-%f') - - # define SAVE_FOLDER as the timestamp - base model name - mask filling model name - # create it if it doesn't exist - precision_string = "int8" if args.int8 else ("fp16" if args.half else "fp32") - sampling_string = "top_k" if args.do_top_k else ("top_p" if args.do_top_p else "temp") - output_subfolder = f"{args.output_name}/" if args.output_name else "" - if args.openai_model is None: - base_model_name = args.base_model_name.replace('/', '_') - else: - base_model_name = "openai-" + args.openai_model.replace('/', '_') - scoring_model_string = (f"-{args.scoring_model_name}" if args.scoring_model_name else "").replace('/', '_') - if args.original_a and args.original_b: - suf_org = 'org_org' - elif args.original_a and not args.original_b: - suf_org="orga" - elif args.original_b and not args.original_a: - suf_org="" #TODO OOOO - else: - suf_org = "gen_gen" - - n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")] - if len(n_perturbation_list) == 1 and n_perturbation_list[0] == 25: - suf_pert = "" - else: - suf_pert = f"_{n_perturbation_list[0]}" - - SAVE_FOLDER = f"tmp_results_cross/{output_subfolder}{base_model_name}-{args.revision}{scoring_model_string}-{args.gen_a.replace('/','_')}-{args.gen_b.replace('/','_')}{suf_org}{suf_pert}" - - print(SAVE_FOLDER) - - - new_folder = SAVE_FOLDER.replace("tmp_results_cross", "results_cross") - ##don't run if exists!!! 
- print(f"{new_folder}") - if os.path.isdir((new_folder)): - print(f"folder exists, not running this exp {new_folder}") - exit(0) - - - gen_a_file = f'{args.root_gen}/{args.gen_a}/raw_data.json' - gen_b_file = f'{args.root_gen}/{args.gen_b}/raw_data.json' - - if not os.path.exists(SAVE_FOLDER): - os.makedirs(SAVE_FOLDER) - print(f"Saving results to absolute path: {os.path.abspath(SAVE_FOLDER)}") - - # write args to file - with open(os.path.join(SAVE_FOLDER, "args.json"), "w") as f: - json.dump(args.__dict__, f, indent=4) - - mask_filling_model_name = args.mask_filling_model_name - n_samples = args.n_samples - batch_size = args.batch_size - n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")] - n_perturbation_rounds = args.n_perturbation_rounds - n_similarity_samples = args.n_similarity_samples - - cache_dir = args.cache_dir - os.environ["XDG_CACHE_HOME"] = cache_dir - if not os.path.exists(cache_dir): - os.makedirs(cache_dir) - print(f"Using cache dir {cache_dir}") - - GPT2_TOKENIZER = transformers.GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir) - - # generic generative model - base_model, base_tokenizer = load_base_model_and_tokenizer(args.base_model_name) - - #reference model if we are doing the lr baseline - if args.ref_model is not None : - ref_model, ref_tokenizer = load_base_model_and_tokenizer(args.ref_model) - - # mask filling t5 model - if not args.baselines_only and not args.random_fills: - int8_kwargs = {} - half_kwargs = {} - if args.int8: - int8_kwargs = dict(load_in_8bit=True, device_map='auto', torch_dtype=torch.bfloat16) - elif args.half: - half_kwargs = dict(torch_dtype=torch.bfloat16) - print(f'Loading mask filling model {mask_filling_model_name}...') - mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(mask_filling_model_name, **int8_kwargs, **half_kwargs, cache_dir=cache_dir) - try: - n_positions = mask_model.config.n_positions - except AttributeError: - n_positions = 512 - else: - n_positions = 512 - preproc_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small', model_max_length=512, cache_dir=cache_dir) - mask_tokenizer = transformers.AutoTokenizer.from_pretrained(mask_filling_model_name, model_max_length=n_positions, cache_dir=cache_dir) - if args.dataset in ['english', 'german']: - preproc_tokenizer = mask_tokenizer - - load_base_model() - - # print(f'Loading dataset {args.dataset}...') - # data = generate_data(args.dataset, args.dataset_key) - #load the generations+samples - - data = load_data_gens(gen_a_file,gen_b_file,args.original_b,args.original_a) #we will return gen_a as original (human) and gen_b as the generated text - if args.random_fills: - FILL_DICTIONARY = set() - for texts in data.values(): - for text in texts: - FILL_DICTIONARY.update(text.split()) - FILL_DICTIONARY = sorted(list(FILL_DICTIONARY)) - - - - if args.scoring_model_name: - print(f'Loading SCORING model {args.scoring_model_name}...') - del base_model - del base_tokenizer - torch.cuda.empty_cache() - base_model, base_tokenizer = load_base_model_and_tokenizer(args.scoring_model_name) - load_base_model() # Load again because we've deleted/replaced the old model - - - outputs = [] - - if not args.baselines_only: - # run perturbation experiments - for n_perturbations in n_perturbation_list: - perturbation_results = get_perturbation_results(args.span_length, n_perturbations, n_samples) - for perturbation_mode in ['d', 'z']: - output = run_perturbation_experiment( - perturbation_results, perturbation_mode, span_length=args.span_length, 
n_perturbations=n_perturbations, n_samples=n_samples) - outputs.append(output) - with open(os.path.join(SAVE_FOLDER, f"perturbation_{n_perturbations}_{perturbation_mode}_results.json"), "w") as f: - json.dump(output, f) - - - - save_roc_curves(outputs) - save_ll_histograms(outputs) - save_llr_histograms(outputs) - - # move results folder from tmp_results/ to results/, making sure necessary directories exist - new_folder = SAVE_FOLDER.replace("tmp_results_cross", "results_cross") - if not os.path.exists(os.path.dirname(new_folder)): - os.makedirs(os.path.dirname(new_folder)) - os.rename(SAVE_FOLDER, new_folder) - - print(f"Used an *estimated* {API_TOKEN_COUNTER} API tokens (may be inaccurate)") \ No newline at end of file diff --git a/mimir/attacks/attack_utils.py b/mimir/attacks/attack_utils.py index 88d5e15..082bfa8 100644 --- a/mimir/attacks/attack_utils.py +++ b/mimir/attacks/attack_utils.py @@ -35,18 +35,10 @@ def apply_extracted_fills(masked_texts: List[str], extracted_fills): return texts -def get_likelihood(logits, labels): - assert logits.shape[0] == 1 - assert labels.shape[0] == 1 - - logits = logits.view(-1, logits.shape[-1])[:-1] - labels = labels.view(-1)[1:] - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - log_likelihood = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - return log_likelihood.mean() - - def f1_score(prediction, ground_truth): + """ + Compute F1 score for given prediction and ground truth. + """ common = Counter(prediction) & Counter(ground_truth) num_same = sum(common.values()) if num_same == 0: diff --git a/mimir/attacks/blackbox_attacks.py b/mimir/attacks/blackbox_attacks.py index a4b9c03..15e4000 100644 --- a/mimir/attacks/blackbox_attacks.py +++ b/mimir/attacks/blackbox_attacks.py @@ -22,13 +22,19 @@ def __init__(self, config, target_model: Model, ref_model: Model = None): self.config = config self.target_model = target_model self.ref_model = ref_model + self.is_loaded = False - def prepare(self, **kwargs): + def load(self): """ Any attack-specific steps (one-time) preparation """ pass + def unload(self): + if self.ref_model is not None: + self.ref_model.unload() + self.is_loaded = False + def _attack(self, document, probs, tokens=None, **kwargs): """ Actual logic for attack. @@ -39,6 +45,11 @@ def attack(self, document, probs, **kwargs): """ Score a document using the attack's scoring function. 
Calls self._attack """ + # Load attack if not loaded yet + if not self.is_loaded: + self.load() + self.is_loaded = True + detokenized_sample = kwargs.get("detokenized_sample", None) if self.config.pretokenized and detokenized_sample is None: raise ValueError("detokenized_sample must be provided") diff --git a/mimir/attacks/neighborhood.py b/mimir/attacks/neighborhood.py index 9b92194..9808b8a 100644 --- a/mimir/attacks/neighborhood.py +++ b/mimir/attacks/neighborhood.py @@ -68,7 +68,7 @@ def _pick_neighbor_model(self): raise ValueError(f"Unknown model {self.config.neighborhood_config.model}") return mask_model - def prepare(self, **kwargs): + def load(self): """ Any attack-specific steps (one-time) preparation """ diff --git a/mimir/attacks/reference.py b/mimir/attacks/reference.py index 0a73f5e..51be011 100644 --- a/mimir/attacks/reference.py +++ b/mimir/attacks/reference.py @@ -8,12 +8,12 @@ class ReferenceAttack(Attack): def __init__(self, config, model, reference_model): super().__init__(config, model, reference_model) - def prepare(self, **kwargs): - self.reference_model.load() + def load(self): + self.ref_model.load() def _attack(self, document, probs, tokens=None, **kwargs): loss = kwargs.get('loss', None) if loss is None: loss = self.model.get_ll(document, probs=probs, tokens=tokens) - ref_loss = self.reference_model.get_ll(document, probs=probs, tokens=tokens) + ref_loss = self.ref_model.get_ll(document, probs=probs, tokens=tokens) return ref_loss - loss diff --git a/mimir/attacks/utils.py b/mimir/attacks/utils.py index b76dab6..1339e55 100644 --- a/mimir/attacks/utils.py +++ b/mimir/attacks/utils.py @@ -7,7 +7,6 @@ from mimir.attacks.neighborhood import NeighborhoodAttack - # TODO Use decorators to link attack implementations with enum above def get_attacker(attack: str): mapping = { diff --git a/mimir/models.py b/mimir/models.py index 1c9636d..09c9a50 100644 --- a/mimir/models.py +++ b/mimir/models.py @@ -34,13 +34,13 @@ def __init__(self, config: ExperimentConfig, **kwargs): self.name = None self.kwargs = kwargs self.cache_dir = self.config.env_config.cache_dir - + def to(self, device): """ Shift model to a particular device. """ - self.model.to(device) - + self.model.to(device, non_blocking=True) + def load(self): """ Load model onto GPU (and compile, if requested) if not already loaded with device map. @@ -56,7 +56,7 @@ def load(self): if self.config.env_config.compile: torch.compile(self.model) print(f'DONE ({time.time() - start:.2f}s)') - + def unload(self): """ Unload model from GPU @@ -110,7 +110,7 @@ def get_probabilities(self, text: str, tokens=None): assert len(all_prob) == labels.size(1) - 1 return all_prob - + @torch.no_grad() def get_ll(self, text: str, tokens=None, probs=None): """ @@ -118,7 +118,7 @@ def get_ll(self, text: str, tokens=None, probs=None): """ all_prob = probs if probs is not None else self.get_probabilities(text, tokens=tokens) return -np.mean(all_prob) - + def load_base_model_and_tokenizer(self, model_kwargs): """ Load the base model and tokenizer for a given model name. 
@@ -135,7 +135,7 @@ def load_base_model_and_tokenizer(self, model_kwargs): self.name, **model_kwargs, device_map=self.device, cache_dir=self.cache_dir) # Extract the model from the model wrapper so we dont need to call model.model elif "llama" in self.name or "alpaca" in self.name: - # TODO: This should be smth specified in config in case user has + # TODO: This should be smth specified in config in case user has # llama is too big, gotta use device map model = transformers.AutoModelForCausalLM.from_pretrained(self.name, **model_kwargs, device_map="balanced_low_0", cache_dir=self.cache_dir) self.device = 'cuda:1' @@ -174,12 +174,12 @@ def load_base_model_and_tokenizer(self, model_kwargs): tokenizer.add_special_tokens({'pad_token': '[PAD]'}) return model, tokenizer - + def load_model_properties(self): """ Load model properties, such as max length and stride. """ - # TODO: getting max_length of input could be more generic + # TODO: getting max_length of input could be more generic if "silo" in self.name or "balanced" in self.name: self.max_length = self.model.model.seq_len elif hasattr(self.model.config, 'max_position_embeddings'): @@ -215,6 +215,14 @@ def __init__(self, config: ExperimentConfig, name: str): model_kwargs=base_model_kwargs) self.load_model_properties() + def load(self): + if "llama" not in self.name and "alpaca" not in self.name: + super().load() + + def unload(self): + if "llama" not in self.name and "alpaca" not in self.name: + super().unload() + class QuantileReferenceModel(Model): """ @@ -235,28 +243,6 @@ def __init__(self, config: ExperimentConfig, name: str): self.load_model_properties() -class EvalModel(Model): - """ - GPT-based detector that can distinguish between machine-generated and human-written text - """ - def __init__(self, config: ExperimentConfig): - super().__init__(config) - self.device = self.config.env_config.device_aux - self.name = 'roberta-base-openai-detector' - self.model = transformers.AutoModelForSequenceClassification.from_pretrained(self.name, cache_dir=self.cache_dir).to(self.device) - self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.name, cache_dir=self.cache_dir) - - @torch.no_grad() - def get_preds(self, data): - batch_size = self.config.batch_size - preds = [] - for batch in tqdm(range(len(data) // batch_size), desc="Evaluating fake"): - batch_fake = data[batch * batch_size:(batch + 1) * batch_size] - batch_fake = self.tokenizer(batch_fake, padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device) - preds.extend(self.model(**batch_fake).logits.softmax(-1)[:,0].tolist()) - return preds - - class LanguageModel(Model): """ Generic LM- used most often for target model diff --git a/run.py b/run.py index c9af896..39eb49c 100644 --- a/run.py +++ b/run.py @@ -10,7 +10,7 @@ import pickle import math from collections import defaultdict -from typing import List +from typing import List, Dict from simple_parsing import ArgumentParser from pathlib import Path @@ -26,8 +26,8 @@ import mimir.data_utils as data_utils import mimir.plot_utils as plot_utils from mimir.utils import fix_seed -from mimir.models import EvalModel, LanguageModel, ReferenceModel, OpenAI_APIModel -from mimir.attacks.blackbox_attacks import BlackBoxAttacks +from mimir.models import LanguageModel, ReferenceModel, OpenAI_APIModel +from mimir.attacks.blackbox_attacks import BlackBoxAttacks, Attack from mimir.attacks.utils import get_attacker from mimir.attacks.neighborhood import T5Model, BertModel, NeighborhoodAttack from 
mimir.attacks.attack_utils import ( @@ -37,28 +37,14 @@ get_auc_from_thresholds, ) -# TODO: Might make more sense to have this function called once each for mem and non-mem, instead of handling all of them in a loop inside here -# For instance, makes it easy to know exact source of data -def run_blackbox_attacks( - data, - ds_objects, + +def get_attackers( target_model, ref_models, config: ExperimentConfig, - n_samples: int = None, - batch_size: int = 50, - keys_care_about: List[str] = ["nonmember", "member"], - scores_not_needed: bool = False, ): - fix_seed(config.random_seed) - - n_samples = len(data["nonmember"]) if n_samples is None else n_samples - - # Structure: attack -> member scores/nonmember scores - # For both members and nonmembers, we compute all attacks - # listed in config all together for each sample + # Look at all attacks, and attacks that we have implemented attacks = config.blackbox_attacks - neigh_config = config.neighborhood_config implemented_blackbox_attacks = [a.value for a in BlackBoxAttacks] # check for unimplemented attacks runnable_attacks = [] @@ -66,149 +52,177 @@ def run_blackbox_attacks( if a not in implemented_blackbox_attacks: print(f"Attack {a} not implemented, will be ignored") pass - runnable_attacks.append(a) attacks = runnable_attacks - if neigh_config: - n_perturbation_list = neigh_config.n_perturbation_list - in_place_swap = neigh_config.original_tokenization_swap - # Initialize attackers attackers = {} for attack in attacks: if attack != BlackBoxAttacks.REFERENCE_BASED: attackers[attack] = get_attacker(attack)(config, target_model) - attackers[attack].prepare() - results = defaultdict(list) - for classification in keys_care_about: - print(f"Running for classification {classification}") + # Initialize reference-based attackers + for name, ref_model in ref_models.items(): + attacker = get_attacker(BlackBoxAttacks.REFERENCE_BASED)( + config, target_model, ref_model + ) + attackers[f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}"] = attacker + return attackers - neighbors = None - if BlackBoxAttacks.NEIGHBOR in attacks and neigh_config.load_from_cache: - neighbors = data[f"{classification}_neighbors"] - print("Loaded neighbors from cache!") - collected_neighbors = { - n_perturbation: [] for n_perturbation in n_perturbation_list - } +def get_mia_scores( + data, + attackers_dict: Dict[str, Attack], + ds_object, + target_model: LanguageModel, + ref_models: Dict[str, ReferenceModel], + config: ExperimentConfig, + is_train: bool, + n_samples: int = None, + batch_size: int = 50, +): + # Fix randomness + fix_seed(config.random_seed) - # For each batch of data - # TODO: Batch-size isn't really "batching" data - change later - for batch in tqdm( - range(math.ceil(n_samples / batch_size)), desc=f"Computing criterion" - ): - texts = data[classification][batch * batch_size : (batch + 1) * batch_size] - - # For each entry in batch - for idx in range(len(texts)): - sample_information = defaultdict(list) - sample = ( - texts[idx][: config.max_substrs] - if config.full_doc - else [texts[idx]] - ) + n_samples = len(data["records"]) if n_samples is None else n_samples - # This will be a list of integers if pretokenized - sample_information["sample"] = sample - if config.pretokenized: - detokenized_sample = [ - target_model.tokenizer.decode(s) for s in sample - ] - sample_information["detokenized"] = detokenized_sample - - # For each substring - neighbors_within = { - n_perturbation: [] for n_perturbation in n_perturbation_list - } - for i, substr in 
enumerate(sample): - # compute token probabilities for sample - s_tk_probs = ( - target_model.get_probabilities(substr) - if not config.pretokenized - else target_model.get_probabilities( - detokenized_sample[i], tokens=substr - ) - ) + # Look at all attacks, and attacks that we have implemented + neigh_config = config.neighborhood_config - # Always compute LOSS score. Also helpful for reference-based and many other attacks. - loss = ( - target_model.get_ll(substr, probs=s_tk_probs) - if not config.pretokenized - else target_model.get_ll( - detokenized_sample[i], tokens=substr, probs=s_tk_probs - ) - ) - sample_information[BlackBoxAttacks.LOSS].append(loss) - - # TODO: Shift functionality into each attack entirely, so that this is just a for loop - # For each attack - for attack in attacks: - # LOSS already added above, Reference handled later - if attack == BlackBoxAttacks.REFERENCE_BASED or BlackBoxAttacks.LOSS: - continue - - attacker = attackers[attack] - if attack != BlackBoxAttacks.NEIGHBOR: - score = attacker.attack(substr, - probs=s_tk_probs, - detokenized_sample=detokenized_sample[i] if config.pretokenized else None, - loss=loss) - sample_information[attack].append(score) - - if attack == BlackBoxAttacks.NEIGHBOR: - # For each 'number of neighbors' - for n_perturbation in n_perturbation_list: - # Use neighbors if available - if neighbors: - substr_neighbors = neighbors[n_perturbation][ - batch * batch_size + idx - ][i] - else: - substr_neighbors = attacker.get_neighbors( - [substr], n_perturbations=n_perturbation - ) - # Collect this neighbor information if neigh_config.dump_cache is True - if neigh_config.dump_cache: - neighbors_within[n_perturbation].append( - substr_neighbors - ) - - if not neigh_config.dump_cache: - # Only evaluate neighborhood attack when not caching neighbors - mean_substr_score = target_model.get_lls( - substr_neighbors, batch_size=4 - ) - d_based_score = loss - mean_substr_score + if neigh_config: + n_perturbation_list = neigh_config.n_perturbation_list + in_place_swap = neigh_config.original_tokenization_swap - sample_information[ - f"{attack}-{n_perturbation}" - ].append(d_based_score) + results = [] + # FROM HERE + neighbors = None + if BlackBoxAttacks.NEIGHBOR in attackers_dict.keys() and neigh_config.load_from_cache: + neighbors = data[f"neighbors"] + print("Loaded neighbors from cache!") - if neigh_config and neigh_config.dump_cache: - for n_perturbation in n_perturbation_list: - collected_neighbors[n_perturbation].append( - neighbors_within[n_perturbation] - ) + collected_neighbors = { + n_perturbation: [] for n_perturbation in n_perturbation_list + } - # Add the scores we collected for each sample for each - # attack into to respective list for its classification - results[classification].append(sample_information) + # For each batch of data + # TODO: Batch-size isn't really "batching" data - change later + for batch in tqdm(range(math.ceil(n_samples / batch_size)), desc=f"Computing criterion"): + texts = data["records"][batch * batch_size : (batch + 1) * batch_size] + + # For each entry in batch + for idx in range(len(texts)): + sample_information = defaultdict(list) + sample = ( + texts[idx][: config.max_substrs] + if config.full_doc + else [texts[idx]] + ) - if neigh_config and neigh_config.dump_cache: - ds_obj_use = ds_objects[classification] + # This will be a list of integers if pretokenized + sample_information["sample"] = sample + if config.pretokenized: + detokenized_sample = [target_model.tokenizer.decode(s) for s in sample] + 
sample_information["detokenized"] = detokenized_sample + + # For each substring + neighbors_within = {n_perturbation: [] for n_perturbation in n_perturbation_list} + for i, substr in enumerate(sample): + # compute token probabilities for sample + s_tk_probs = ( + target_model.get_probabilities(substr) + if not config.pretokenized + else target_model.get_probabilities( + detokenized_sample[i], tokens=substr + ) + ) - # Save p_member_text and p_nonmember_text (Lists of strings) to cache - # For each perturbation - for n_perturbation in n_perturbation_list: - ds_obj_use.dump_neighbors( - collected_neighbors[n_perturbation], - train=True if classification == "member" else False, - num_neighbors=n_perturbation, - model=neigh_config.model, - in_place_swap=in_place_swap, + # Always compute LOSS score. Also helpful for reference-based and many other attacks. + loss = ( + target_model.get_ll(substr, probs=s_tk_probs) + if not config.pretokenized + else target_model.get_ll( + detokenized_sample[i], tokens=substr, probs=s_tk_probs + ) ) + sample_information[BlackBoxAttacks.LOSS].append(loss) + + # TODO: Shift functionality into each attack entirely, so that this is just a for loop + # For each attack + for attack, attacker in attackers_dict.items(): + # LOSS already added above, Reference handled later + if attack.startswith(BlackBoxAttacks.REFERENCE_BASED) or attack == BlackBoxAttacks.LOSS: + continue + + if attack != BlackBoxAttacks.NEIGHBOR: + score = attacker.attack( + substr, + probs=s_tk_probs, + detokenized_sample=( + detokenized_sample[i] + if config.pretokenized + else None + ), + loss=loss, + ) + sample_information[attack].append(score) + else: + # For each 'number of neighbors' + for n_perturbation in n_perturbation_list: + # Use neighbors if available + if neighbors: + substr_neighbors = neighbors[n_perturbation][ + batch * batch_size + idx + ][i] + else: + substr_neighbors = attacker.get_neighbors( + [substr], n_perturbations=n_perturbation + ) + # Collect this neighbor information if neigh_config.dump_cache is True + if neigh_config.dump_cache: + neighbors_within[n_perturbation].append( + substr_neighbors + ) + + if not neigh_config.dump_cache: + # Only evaluate neighborhood attack when not caching neighbors + score = attacker.attack( + substr, + probs=s_tk_probs, + detokenized_sample=( + detokenized_sample[i] + if config.pretokenized + else None + ), + loss=loss, + batch_siz=4, + substr_neighbors=substr_neighbors, + ) + + sample_information[ + f"{attack}-{n_perturbation}" + ].append(score) + + if neigh_config and neigh_config.dump_cache: + for n_perturbation in n_perturbation_list: + collected_neighbors[n_perturbation].append( + neighbors_within[n_perturbation] + ) + + # Add the scores we collected for each sample for each + # attack into to respective list for its classification + results.append(sample_information) + + if neigh_config and neigh_config.dump_cache: + # Save p_member_text and p_nonmember_text (Lists of strings) to cache + # For each perturbation + for n_perturbation in n_perturbation_list: + ds_object.dump_neighbors( + collected_neighbors[n_perturbation], + train=is_train, + num_neighbors=n_perturbation, + model=neigh_config.model, + in_place_swap=in_place_swap, + ) if neigh_config and neigh_config.dump_cache: print( @@ -217,51 +231,62 @@ def run_blackbox_attacks( exit(0) # Perform reference-based attacks - if BlackBoxAttacks.REFERENCE_BASED in attacks and ref_models is not None: + if ref_models is not None: for name, ref_model in ref_models.items(): - if "llama" not in 
name and "alpaca" not in name: - ref_model.load() - - # attacker = get_attacker(BlackBoxAttacks.REFERENCE_BASED)(config, target_model, ref_model) + ref_key = f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}" + attacker = attackers_dict.get(ref_key, None) + if attacker is None: + continue # Update collected scores for each sample with ref-based attack scores - for classification, result in results.items(): - for r in tqdm(result, desc="Ref scores"): - ref_model_scores = [] - for i, s in enumerate(r["sample"]): - if config.pretokenized: - s = r["detokenized"][i] - ref_score = r[BlackBoxAttacks.LOSS][i] - ref_model.get_ll(s) - ref_model_scores.append(ref_score) - r[ - f"{BlackBoxAttacks.REFERENCE_BASED}-{name.split('/')[-1]}" - ].extend(ref_model_scores) - - if "llama" not in name and "alpaca" not in name: - ref_model.unload() + for r in tqdm(results, desc="Ref scores"): + ref_model_scores = [] + for i, s in enumerate(r["sample"]): + if config.pretokenized: + s = r["detokenized"][i] + score = attacker.attack(s, probs=None, + loss=r[BlackBoxAttacks.LOSS][i]) + ref_model_scores.append(score) + r[ref_key].extend(ref_model_scores) + + attacker.unload() else: print("No reference models specified, skipping Reference-based attacks") # Rearrange the nesting of the results dict and calculated aggregated score for sample # attack -> member/nonmember -> list of scores - samples = defaultdict(list) - predictions = defaultdict(lambda: defaultdict(list)) - for classification, result in results.items(): - for r in result: - samples[classification].append(r["sample"]) - for attack, scores in r.items(): - if attack != "sample" and attack != "detokenized": - predictions[attack][classification].append(np.min(scores)) - - if scores_not_needed: - return predictions + samples = [] + predictions = defaultdict(lambda: []) + for r in results: + samples.append(r["sample"]) + for attack, scores in r.items(): + if attack != "sample" and attack != "detokenized": + # TODO: Is there a reason for the np.min here? 

     # Rearrange the nesting of the results dict and calculate aggregated scores for each sample
     # attack -> member/nonmember -> list of scores
-    samples = defaultdict(list)
-    predictions = defaultdict(lambda: defaultdict(list))
-    for classification, result in results.items():
-        for r in result:
-            samples[classification].append(r["sample"])
-            for attack, scores in r.items():
-                if attack != "sample" and attack != "detokenized":
-                    predictions[attack][classification].append(np.min(scores))
-
-    if scores_not_needed:
-        return predictions
+    samples = []
+    predictions = defaultdict(lambda: [])
+    for r in results:
+        samples.append(r["sample"])
+        for attack, scores in r.items():
+            if attack != "sample" and attack != "detokenized":
+                # TODO: Is there a reason for the np.min here?
+                predictions[attack].append(np.min(scores))
+
+    return predictions, samples
+
+
+def compute_metrics_from_scores(
+        preds_member: dict,
+        preds_nonmember: dict,
+        samples_member: List,
+        samples_nonmember: List,
+        n_samples: int):
+
+    attack_keys = list(preds_member.keys())
+    if attack_keys != list(preds_nonmember.keys()):
+        raise ValueError("Mismatched attack keys for member/nonmember predictions")

     # Collect outputs for each attack
     blackbox_attack_outputs = {}
-    for attack, prediction in tqdm(predictions.items()):
+    for attack in attack_keys:
+        preds_member_ = preds_member[attack]
+        preds_nonmember_ = preds_nonmember[attack]
+
         fpr, tpr, roc_auc, roc_auc_res, thresholds = get_roc_metrics(
-            preds_member=prediction["member"],
-            preds_nonmember=prediction["nonmember"],
+            preds_member=preds_member_,
+            preds_nonmember=preds_nonmember_,
             perform_bootstrap=True,
             return_thresholds=True,
         )
@@ -270,7 +295,8 @@ def run_blackbox_attacks(
             for upper_bound in config.fpr_list
         }
         p, r, pr_auc = get_precision_recall_metrics(
-            preds_member=prediction["member"], preds_nonmember=prediction["nonmember"]
+            preds_member=preds_member_,
+            preds_nonmember=preds_nonmember_
         )

         print(
@@ -278,11 +304,18 @@ def run_blackbox_attacks(
         )
         blackbox_attack_outputs[attack] = {
             "name": f"{attack}_threshold",
-            "predictions": prediction,
+            "predictions": {
+                "member": preds_member_,
+                "nonmember": preds_nonmember_,
+            },
             "info": {
                 "n_samples": n_samples,
             },
-            "raw_results": samples if not config.pretokenized else [],
+            "raw_results": (
+                {"member": samples_member, "nonmember": samples_nonmember}
+                if not config.pretokenized
+                else []
+            ),
             "metrics": {
                 "roc_auc": roc_auc,
                 "fpr": fpr,
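The tpr_at_low_fpr entries populated in the hunks above read the true-positive rate off the ROC curve at fixed low-FPR operating points; a minimal sketch with scikit-learn on synthetic scores (an illustration only, not mimir's get_roc_metrics):

    import numpy as np
    from sklearn.metrics import roc_curve, auc

    rng = np.random.default_rng(0)
    # Synthetic scores: members drawn slightly higher than non-members.
    member = rng.normal(1.0, 1.0, 1000)
    nonmember = rng.normal(0.0, 1.0, 1000)

    labels = np.r_[np.ones_like(member), np.zeros_like(nonmember)]
    scores = np.r_[member, nonmember]
    fpr, tpr, _ = roc_curve(labels, scores)
    print("ROC AUC:", auc(fpr, tpr))

    # TPR at the last operating point whose FPR is still below the bound,
    # mirroring tpr[np.where(np.array(fpr) < upper_bound)[0][-1]] above.
    for upper_bound in [0.001, 0.01, 0.1]:
        print(upper_bound, tpr[np.where(fpr < upper_bound)[0][-1]])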
@@ -392,61 +425,6 @@ def generate_data(
     # return generate_samples(data[:n_samples], batch_size=batch_size)


-def eval_supervised(data, model):
-    print(f"Beginning supervised evaluation with {model}...")
-
-    real, fake = data["nonmember"], data["member"]
-
-    # TODO: Fix init call below
-    eval_model = EvalModel(model)
-
-    real_preds = eval_model.get_preds(real)
-    fake_preds = eval_model.get_preds(fake)
-
-    predictions = {
-        "real": real_preds,
-        "samples": fake_preds,
-    }
-
-    fpr, tpr, roc_auc, roc_auc_res = get_roc_metrics(
-        preds_member=real_preds, preds_nonmember=fake_preds, perform_bootstrap=True
-    )
-    tpr_at_low_fpr = {
-        upper_bound: tpr[np.where(np.array(fpr) < upper_bound)[0][-1]]
-        for upper_bound in config.fpr_list
-    }
-    p, r, pr_auc = get_precision_recall_metrics(
-        preds_member=real_preds, preds_nonmember=fake_preds
-    )
-    print(f"{model} ROC AUC: {roc_auc}, PR AUC: {pr_auc}")
-
-    del eval_model
-    # Clear CUDA cache
-    torch.cuda.empty_cache()
-
-    return {
-        "name": model,
-        "predictions": predictions,
-        "info": {
-            "n_samples": n_samples,
-        },
-        "metrics": {
-            "roc_auc": roc_auc,
-            "fpr": fpr,
-            "tpr": tpr,
-            "bootstrap_roc_auc_mean": np.mean(roc_auc_res.bootstrap_distribution),
-            "bootstrap_roc_auc_std": roc_auc_res.standard_error,
-            "tpr_at_low_fpr": tpr_at_low_fpr,
-        },
-        "pr_metrics": {
-            "pr_auc": pr_auc,
-            "precision": p,
-            "recall": r,
-        },
-        "loss": 1 - pr_auc,
-    }
-
-
 def main(config: ExperimentConfig):
     env_config: EnvironmentConfig = config.env_config
     neigh_config: NeighborhoodConfig = config.neighborhood_config
@@ -714,7 +692,7 @@ def main(config: ExperimentConfig):
         )
         json.dump(seq_lens, f)

-    # Remove below if not needed/used
+    # TODO: Remove below if not needed/used
     """
     tk_freq_map = None
     if config.token_frequency_map is not None:
@@ -722,82 +700,89 @@ def main(config: ExperimentConfig):
         tk_freq_map = pickle.load(open(config.token_frequency_map, "rb"))
     """

-    # Add neighborhood-related data entries to 'data'
-    data["nonmember_neighbors"] = neighbors_nonmember
-    data["member_neighbors"] = neighbors_member
-
-    ds_objects = {"nonmember": data_obj_nonmem, "member": data_obj_mem}
+    # TODO: Instead of extracting from 'data', construct directly somewhere above
+    data_members = {
+        "records": data["member"],
+        "neighbors": neighbors_member,
+    }
+    data_nonmembers = {
+        "records": data["nonmember"],
+        "neighbors": neighbors_nonmember,
+    }

     outputs = []
-    if config.blackbox_attacks is not None:
-        # perform blackbox attacks
-        blackbox_outputs = run_blackbox_attacks(
-            data,
-            ds_objects=ds_objects,
-            target_model=base_model,
-            ref_models=ref_models,
-            config=config,
-            n_samples=n_samples,
-        )
+    if config.blackbox_attacks is None:
+        raise ValueError("No blackbox attacks specified in config!")
+
+    # Prepare attackers
+    attackers_dict = get_attackers(base_model, ref_models, config)
+
+    # Collect scores for members
+    member_preds, member_samples = get_mia_scores(
+        data_members,
+        attackers_dict,
+        data_obj_mem,
+        target_model=base_model,
+        ref_models=ref_models,
+        config=config,
+        is_train=True,
+        n_samples=n_samples
+    )
+    # Collect scores for non-members
+    nonmember_preds, nonmember_samples = get_mia_scores(
+        data_nonmembers,
+        attackers_dict,
+        data_obj_nonmem,
+        target_model=base_model,
+        ref_models=ref_models,
+        config=config,
+        is_train=False,
+        n_samples=n_samples,
+    )
+    blackbox_outputs = compute_metrics_from_scores(
+        member_preds,
+        nonmember_preds,
+        member_samples,
+        nonmember_samples,
+        n_samples=n_samples,
+    )

-        # TODO: For now, AUCs for other sources of non-members are only printed (not saved)
-        # Will fix later!
-        if config.dataset_nonmember_other_sources is not None:
-            # Using thresholds returned in blackbox_outputs, compute AUCs and ROC curves for other non-member sources
-            for other_obj, other_nonmember, other_name in zip(
-                other_objs, other_nonmembers, config.dataset_nonmember_other_sources
-            ):
-                # other_data, _, other_n_samples = generate_data_processed(
-                #     other_nonmember, batch_size=config.batch_size
-                # )
-                other_ds_objects = {"nonmember": other_obj}
-                other_blackbox_predictions = run_blackbox_attacks(
-                    data={"nonmember": other_nonmember},
-                    ds_objects=other_ds_objects,
-                    target_model=base_model,
-                    ref_models=ref_models,
-                    config=config,
-                    n_samples=n_samples,
-                    keys_care_about=["nonmember"],
-                    scores_not_needed=True,
-                )

+    # TODO: For now, AUCs for other sources of non-members are only printed (not saved)
+    # Will fix later!
+    if config.dataset_nonmember_other_sources is not None:
+        # Using thresholds returned in blackbox_outputs, compute AUCs and ROC curves for other non-member sources
+        for other_obj, other_nonmember, other_name in zip(
+            other_objs, other_nonmembers, config.dataset_nonmember_other_sources
+        ):
+            other_nonmem_preds, _ = get_mia_scores(
+                other_nonmember,
+                attackers_dict,
+                other_obj,
+                target_model=base_model,
+                ref_models=ref_models,
+                config=config,
+                is_train=False,
+                n_samples=n_samples,
+            )

-                for attack in blackbox_outputs.keys():
-                    member_scores = np.array(
-                        blackbox_outputs[attack]["predictions"]["member"]
-                    )
-                    thresholds = blackbox_outputs[attack]["metrics"]["thresholds"]
-                    nonmember_scores = np.array(
-                        other_blackbox_predictions[attack]["nonmember"]
-                    )
-                    auc = get_auc_from_thresholds(
-                        member_scores, nonmember_scores, thresholds
-                    )
-                    print(
-                        f"AUC using thresholds of original split on {other_name} using {attack}: {auc}"
-                    )
-                exit(0)
+            for attack in blackbox_outputs.keys():
+                member_scores = np.array(member_preds[attack])
+                thresholds = blackbox_outputs[attack]["metrics"]["thresholds"]
+                nonmember_scores = np.array(other_nonmem_preds[attack])
+                auc = get_auc_from_thresholds(
+                    member_scores, nonmember_scores, thresholds
+                )
+                print(
+                    f"AUC using thresholds of original split on {other_name} using {attack}: {auc}"
+                )
+            exit(0)

-        # TODO: Skipping openai-detector (for now)
-        # if config.max_tokens < 512:
-        #     baseline_outputs.append(eval_supervised(data, model='roberta-base-openai-detector'))
-        #     baseline_outputs.append(eval_supervised(data, model='roberta-large-openai-detector'))
-
-        for attack, output in blackbox_outputs.items():
-            outputs.append(output)
-            with open(os.path.join(SAVE_FOLDER, f"{attack}_results.json"), "w") as f:
-                json.dump(output, f)
-
-    # Skipping openai-detector (for now)
-    # write supervised results to a file
-    # TODO: update to read from baseline result dict
-    # if config.max_tokens < 512:
-    #     with open(os.path.join(SAVE_FOLDER, f"roberta-base-openai-detector_results.json"), "w") as f:
-    #         json.dump(baseline_outputs[-2], f)
-
-    #     # write supervised results to a file
-    #     with open(os.path.join(SAVE_FOLDER, f"roberta-large-openai-detector_results.json"), "w") as f:
-    #         json.dump(baseline_outputs[-1], f)
+    for attack, output in blackbox_outputs.items():
+        outputs.append(output)
+        with open(os.path.join(SAVE_FOLDER, f"{attack}_results.json"), "w") as f:
+            json.dump(output, f)

     neighbor_model_name = neigh_config.model if neigh_config else None
     # TODO: Fix TPR/FPR computation
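For intuition, get_auc_from_thresholds above scores a new non-member source against thresholds calibrated on the original member/non-member split, rather than re-deriving an ROC curve from scratch. A standalone sketch of that idea (an illustration only, not mimir's implementation; it assumes lower scores are more member-like):

    import numpy as np

    def auc_from_fixed_thresholds(member_scores, nonmember_scores, thresholds):
        # Sweep the *given* thresholds and integrate TPR over FPR (trapezoid rule).
        thresholds = np.sort(np.asarray(thresholds))
        tpr = np.array([(member_scores < t).mean() for t in thresholds])
        fpr = np.array([(nonmember_scores < t).mean() for t in thresholds])
        order = np.argsort(fpr)
        return np.trapz(tpr[order], fpr[order])

    rng = np.random.default_rng(0)
    member = rng.normal(-1.0, 1.0, 500)          # e.g. lower loss on training members
    other_nonmember = rng.normal(0.0, 1.0, 500)  # non-members from a different source
    thresholds = np.quantile(np.r_[member, other_nonmember], np.linspace(0, 1, 101))
    print(auc_from_fixed_thresholds(member, other_nonmember, thresholds))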