From cc41fb4d290ef88f4bb24d8e9e11f801f04d2ad4 Mon Sep 17 00:00:00 2001
From: Anshuman
Date: Wed, 20 Nov 2024 23:55:33 -0500
Subject: [PATCH] Add DC-PDD

---
 mimir/attacks/all_attacks.py |  15 +++--
 mimir/attacks/dc_pdd.py      | 123 +++++++++++++++++++++++++++++++++++
 mimir/attacks/utils.py       |   4 +-
 mimir/models.py              |   4 +-
 setup.py                     |   2 +-
 5 files changed, 137 insertions(+), 11 deletions(-)
 create mode 100644 mimir/attacks/dc_pdd.py

diff --git a/mimir/attacks/all_attacks.py b/mimir/attacks/all_attacks.py
index 4692695..7d35398 100644
--- a/mimir/attacks/all_attacks.py
+++ b/mimir/attacks/all_attacks.py
@@ -8,14 +8,15 @@
 
 # Attack definitions
 class AllAttacks(str, Enum):
-    LOSS = "loss" # Done
-    REFERENCE_BASED = "ref" # Done
-    ZLIB = "zlib" # Done
-    MIN_K = "min_k" # Done
-    MIN_K_PLUS_PLUS = "min_k++" # Done
-    NEIGHBOR = "ne" # Done
-    GRADNORM = "gradnorm" # Done
+    LOSS = "loss"
+    REFERENCE_BASED = "ref"
+    ZLIB = "zlib"
+    MIN_K = "min_k"
+    MIN_K_PLUS_PLUS = "min_k++"
+    NEIGHBOR = "ne"
+    GRADNORM = "gradnorm"
     RECALL = "recall"
+    DC_PDD = "dc_pdd"
     # QUANTILE = "quantile" # Uncomment when tested implementation is available
 
 
diff --git a/mimir/attacks/dc_pdd.py b/mimir/attacks/dc_pdd.py
new file mode 100644
index 0000000..8b572fc
--- /dev/null
+++ b/mimir/attacks/dc_pdd.py
@@ -0,0 +1,123 @@
+"""
+    DC-PDD Attack: https://aclanthology.org/2024.emnlp-main.300/
+    Based on the official implementation: https://github.com/zhang-wei-chao/DC-PDD
+"""
+import torch as ch
+from tqdm import tqdm
+import numpy as np
+import requests
+import io
+import gzip
+import os
+import json
+from mimir.attacks.all_attacks import Attack
+from mimir.models import Model
+from mimir.config import ExperimentConfig
+from mimir.utils import get_cache_path
+
+
+def ensure_parent_directory_exists(filename):
+    # Get the parent directory from the given filename
+    parent_dir = os.path.dirname(filename)
+
+    # Create the parent directory if it does not exist
+    if parent_dir and not os.path.exists(parent_dir):
+        os.makedirs(parent_dir, exist_ok=True)
+
+
+class DC_PDDAttack(Attack):
+
+    def __init__(self, config: ExperimentConfig, model: Model):
+        super().__init__(config, model, ref_model=None)
+        # Use subset of C-4
+        self.fre_dis = ch.zeros(model.tokenizer.vocab_size)
+        # Account for model name
+        model_name = model.name
+
+        # Load from cache if available, save otherwise
+        cached_file_path = os.path.join(get_cache_path(), "DC_PDD_freq_dis", "C4", f"{model_name}.pt")
+
+        if os.path.exists(cached_file_path):
+            self.fre_dis = ch.load(cached_file_path)
+            print(f"Loaded frequency distribution from cache for {model_name}")
+        else:
+            # Make sure the directory exists
+            ensure_parent_directory_exists(cached_file_path)
+            # Collect frequency data
+            self._collect_frequency_data()
+            ch.save(self.fre_dis, cached_file_path)
+            print(f"Saved frequency distribution to cache for {model_name}")
+
+        # Laplace smoothing
+        self.fre_dis = (1 + self.fre_dis) / (ch.sum(self.fre_dis) + len(self.fre_dis))
+
+    def _fre_dis(self, ref_data, max_tok: int = 1024):
+        """
+        Accumulate token occurrence counts from ref_data into self.fre_dis
+        ref_data: reference dataset (iterable of texts)
+        max_tok: maximum number of tokens kept per text (longer texts are truncated)
+        """
+        # Tokenize each text in the reference dataset and count every token occurrence.
+        # index_add_ accumulates repeated ids; `fre_dis[ids] += 1` would count each id only once per text.
+        for text in tqdm(ref_data):
+            input_ids = self.target_model.tokenizer(text, truncation=True, max_length=max_tok).input_ids
+            self.fre_dis.index_add_(0, ch.tensor(input_ids, dtype=ch.long), ch.ones(len(input_ids)))
+
+    def _collect_frequency_data(self, fil_num: int = 15):
+        for i in tqdm(range(fil_num), desc="Downloading and processing dataset"):
+            # Download the dataset split (format spec instead of nested quotes: valid before Python 3.12)
+            url = f"https://huggingface.co/datasets/allenai/c4/resolve/main/en/c4-train.{i:05}-of-01024.json.gz"
+            # Download the file
+            response = requests.get(url)
+            response.raise_for_status()  # Check for download errors
+
+            # Open and parse the .json.gz file - the file is a .json file with one json object per line
+            with gzip.GzipFile(fileobj=io.BytesIO(response.content)) as gz_file:
+                sub_dataset = gz_file.readlines()
+            examples = []
+            # for example in tqdm(sub_dataset):
+            for example in sub_dataset:
+                example = json.loads(example)
+                examples.append(example['text'])
+
+            # Compute the frequency distribution
+            self._fre_dis(examples)
+
+    @ch.no_grad()
+    def _attack(self, document, probs, tokens=None, **kwargs):
+        """
+        DC-PDD Attack: Use frequency distribution of some large corpus to "calibrate" token probabilities
+        and compute a membership score.
+        """
+        # Hyper-params specific to DC-PDD
+        a: float = kwargs.get("a", 0.01)
+
+        # Tokenize text (we process things slightly differently)
+        tokens_og = self.target_model.tokenizer(document, return_tensors="pt").input_ids
+        # Inject EOS token at beginning
+        tokens = ch.cat([ch.tensor([[self.target_model.tokenizer.eos_token_id]]), tokens_og], dim=1).numpy()
+
+        # these are all log probabilities
+        probs_with_start_token = self.target_model.get_probabilities(document, tokens=tokens)
+        x_pro = np.exp(probs_with_start_token)
+
+        indexes = []
+        current_ids = set()
+        input_ids = tokens_og[0]
+        for i, input_id in enumerate(input_ids.tolist()):
+            if input_id not in current_ids:
+                indexes.append(i)
+                current_ids.add(input_id)
+
+        x_pro = x_pro[indexes]
+        x_fre = self.fre_dis[input_ids[indexes]].numpy()
+
+        # Compute alpha values
+        alpha = x_pro * np.log(1 / x_fre)
+
+        # Compute membership score
+        alpha[alpha > a] = a
+
+        beta = - np.mean(alpha)
+
+        return beta
diff --git a/mimir/attacks/utils.py b/mimir/attacks/utils.py
index 766e22b..c97ac3d 100644
--- a/mimir/attacks/utils.py
+++ b/mimir/attacks/utils.py
@@ -8,6 +8,7 @@
 from mimir.attacks.neighborhood import NeighborhoodAttack
 from mimir.attacks.gradnorm import GradNormAttack
 from mimir.attacks.recall import ReCaLLAttack
+from mimir.attacks.dc_pdd import DC_PDDAttack
 
 
 # TODO Use decorators to link attack implementations with enum above
@@ -20,7 +21,8 @@ def get_attacker(attack: str):
         AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
         AllAttacks.NEIGHBOR: NeighborhoodAttack,
         AllAttacks.GRADNORM: GradNormAttack,
-        AllAttacks.RECALL: ReCaLLAttack
+        AllAttacks.RECALL: ReCaLLAttack,
+        AllAttacks.DC_PDD: DC_PDDAttack
     }
     attack_cls = mapping.get(attack, None)
     if attack_cls is None:
diff --git a/mimir/models.py b/mimir/models.py
index 11b4c2e..c80347b 100644
--- a/mimir/models.py
+++ b/mimir/models.py
@@ -78,6 +78,7 @@ def get_probabilities(self,
             text (str): The input text for which to calculate probabilities.
             tokens (numpy.ndarray, optional): An optional array of token ids. If provided, these tokens
             are used instead of tokenizing the input text. Defaults to None.
+            return_all_probs: bool: If True, return all token probabilities. Defaults to False.
 
         Raises:
             ValueError: If the device or name attributes of the instance are not set.
@@ -95,8 +96,7 @@ def get_probabilities(self,
                     # expand first dimension
                     labels = labels.unsqueeze(0)
             else:
-                tokenized = self.tokenizer(
-                    text, return_tensors="pt")
+                tokenized = self.tokenizer(text, return_tensors="pt")
                 labels = tokenized.input_ids
 
             target_token_log_prob = []
diff --git a/setup.py b/setup.py
index 00bf06f..4065690 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
     description="Python package for measuring memorization in LLMs",
     author="Anshuman Suri, Michael Duan, Niloofar Mireshghallah",
     author_email="as9rw@virginia.edu",
-    version="1.0",
+    version="1.1",
     url="https://github.com/iamgroot42/mimir",
     license="MIT",
     python_requires=">=3.9",