diff --git a/README.md b/README.md
index 6dd7e41..1c89c95 100644
--- a/README.md
+++ b/README.md
@@ -58,3 +58,16 @@ To add an attack, create a file for your attack (e.g. `attacks/my_attack.py`) an
 Then, add a name for your attack to the dictionary in `attacks/blackbox_attack.py`.
 
 If you would like to submit your attack to the repository, please open a pull request describing your attack and the paper it is based on.
+
+## Citation
+
+If you use MIMIR in your research, please cite our paper:
+
+```bibtex
+@article{duan2024membership,
+  title={Do Membership Inference Attacks Work on Large Language Models?},
+  author={Michael Duan and Anshuman Suri and Niloofar Mireshghallah and Sewon Min and Weijia Shi and Luke Zettlemoyer and Yulia Tsvetkov and Yejin Choi and David Evans and Hannaneh Hajishirzi},
+  year={2024},
+  journal={arXiv:2402.07841},
+}
+```
\ No newline at end of file
diff --git a/mimir/config.py b/mimir/config.py
index 0e5bac6..13d5d57 100644
--- a/mimir/config.py
+++ b/mimir/config.py
@@ -202,14 +202,17 @@ class ExperimentConfig(Serializable):
     """OpenAI config"""
 
     def __post_init__(self):
-        if self.dump_cache and self.load_from_cache:
+        if self.dump_cache and (self.load_from_cache or self.load_from_hf):
             raise ValueError("Cannot dump and load cache at the same time")
 
         if self.neighborhood_config:
             if (
                 self.neighborhood_config.dump_cache
                 or self.neighborhood_config.load_from_cache
-            ) and not (self.load_from_cache or self.dump_cache):
+            ) and not (self.load_from_cache or self.dump_cache or self.load_from_hf):
                 raise ValueError(
                     "Using dump/load for neighborhood cache without dumping/loading main cache does not make sense"
                 )
+
+            if self.neighborhood_config.dump_cache and (self.neighborhood_config.load_from_cache or self.load_from_hf):
+                raise ValueError("Cannot dump and load neighborhood cache at the same time")
diff --git a/mimir/custom_datasets.py b/mimir/custom_datasets.py
index 1735340..29ce84e 100644
--- a/mimir/custom_datasets.py
+++ b/mimir/custom_datasets.py
@@ -38,10 +38,47 @@ def load_cached(cache_dir,
         print("Loading from HuggingFace!")
         data_split = data_split.replace("train", "member")
         data_split = data_split.replace("test", "nonmember")
-        ds = datasets.load_dataset("iamgroot42/mimir", name=filename, split=data_split)
-        data = collect_hf_data(ds)
-        if len(data) != n_samples:
-            raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
+        if not filename.startswith("the_pile"):
+            raise ValueError(f"HuggingFace data only available for The Pile.")
+
+        SOURCES_UPLOADED = [
+            "arxiv",
+            "dm_mathematics",
+            "github",
+            "hackernews",
+            "pile_cc",
+            "pubmed_central",
+            "wikipedia_(en)",
+            "full_pile",
+            "c4",
+            "temporal_arxiv",
+            "temporal_wiki"
+        ]
+
+        for source in SOURCES_UPLOADED:
+            # Got a match
+            if source in filename and filename.startswith(f"the_pile_{source}"):
+                split = filename.split(f"the_pile_{source}")[1]
+                if split == "":
+                    # The way HF data is uploaded, no split is recorded as "none"
+                    split = "none"
+                else:
+                    # remove the first underscore
+                    split = split[1:]
+                    # remove '<' , '>'
+                    split = split.replace("<", "").replace(">", "")
+                    # Remove "_truncated" from the end, if present
+                    split = split.rsplit("_truncated", 1)[0]
+
+                # Load corresponding dataset
+                ds = datasets.load_dataset("iamgroot42/mimir", name=source, split=split)
+                data = ds[data_split]
+                # Check if the number of samples is correct
+                if len(data) != n_samples:
+                    raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
+                return data
+        # If got here, matching source was not found
+        raise ValueError(f"Requested source {filename} not found in HuggingFace data.")
     else:
         file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", data_split, filename + ".jsonl")
         if not os.path.exists(file_path):
@@ -50,18 +87,6 @@ def load_cached(cache_dir,
     return data
 
 
-def collect_hf_data(ds):
-    """
-        Helper function to collect all data from a given HuggingFace dataset split.
-    """
-    records = [x["text"] for x in ds]
-    # Standard DS
-    if len(records[0]) == 1:
-        records = [x[0] for x in records]
-    # Neighbor data
-    return records
-
-
 def load_data(file_path):
     """
         Load data from a given filepath (.jsonl)
diff --git a/mimir/data_utils.py b/mimir/data_utils.py
index 0885758..2e66ac3 100644
--- a/mimir/data_utils.py
+++ b/mimir/data_utils.py
@@ -92,7 +92,7 @@ def load(self, train: bool, mask_tokenizer=None, specific_source: str = None):
             assert not self.config.full_doc
             data = np.load(self.presampled)
             return data
-        elif self.config.load_from_cache:
+        elif (self.config.load_from_cache or self.config.load_from_hf):
             # Load from cache, if requested
             filename = self._get_name_to_save()
             data = custom_datasets.load_cached(
diff --git a/python_scripts/check_hf.py b/python_scripts/check_hf.py
index 19e2226..0240bfa 100644
--- a/python_scripts/check_hf.py
+++ b/python_scripts/check_hf.py
@@ -1,6 +1,15 @@
 from datasets import load_dataset
 
-ds = load_dataset("mimir.py", "the_pile_full_pile")
-print(ds)
-ds = load_dataset("mimir.py", "the_pile_arxiv")
-print(ds['member'][0]['text'])
+ds = load_dataset("mimir.py", "pile_cc", split="ngram_7_0.2")
+print(ds['member'][0])
+ds = load_dataset("mimir.py", "full_pile", split="none")
+print(len(ds['member']))
+assert len(ds['member']) == 10000
+print(ds["nonmember_neighbors"][0])
+print(len(ds["member_neighbors"]))
+print(ds['member_neighbors'][0][12])
+ds = load_dataset("mimir.py", "arxiv", split="ngram_13_0.8")
+print(ds["nonmember_neighbors"][1][9])
+
+assert len(ds['member']) == 1000
+assert len(ds["nonmember_neighbors"][0]) == 25
diff --git a/python_scripts/mimir.py b/python_scripts/mimir.py
index 88f0d4b..8547afe 100644
--- a/python_scripts/mimir.py
+++ b/python_scripts/mimir.py
@@ -14,6 +14,8 @@
 
 import datasets
 
+from typing import List
+
 _HOMEPAGE = "http://github.com/iamgroot42/mimir"
"http://github.com/iamgroot42/mimir" @@ -23,11 +25,11 @@ """ _CITATION = """\ -@article{duan2024do, - title={Do Membership Inference Attacks Work on Large Language Models?}, - author={Duan*, Michael and \textbf{A. Suri*} and Mireshghallah, Niloofar and Min, Sewon and Shi, Weijia and Zettlemoyer, Luke and Tsvetkov, Yulia and Choi, Yejin and Evans, David and Hajishirzi, Hannaneh}, - journal={arXiv preprint arXiv:???}, - year={2024} +@article{duan2024membership, + title={Do Membership Inference Attacks Work on Large Language Models?}, + author={Michael Duan and Anshuman Suri and Niloofar Mireshghallah and Sewon Min and Weijia Shi and Luke Zettlemoyer and Yulia Tsvetkov and Yejin Choi and David Evans and Hannaneh Hajishirzi}, + year={2024}, + journal={arXiv:2402.07841}, } """ @@ -37,27 +39,71 @@ class MimirConfig(BuilderConfig): """BuilderConfig for Mimir dataset.""" - def __init__(self, **kwargs): + def __init__(self, *args, subsets: List[str]=[], **kwargs): """Constructs a MimirConfig. Args: **kwargs: keyword arguments forwarded to super. """ super(MimirConfig, self).__init__(**kwargs) + self.subsets = subsets class MimirDataset(GeneratorBasedBuilder): # Assuming 'VERSION' is defined - VERSION = datasets.Version("1.0.0") + VERSION = datasets.Version("1.3.0") # Define the builder configs BUILDER_CONFIG_CLASS = MimirConfig BUILDER_CONFIGS = [ MimirConfig( - name="the_pile_arxiv", description="This split contains data from Arxiv" + name="arxiv", + subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"], + description="This split contains data from the Pile's Arxiv subset at various n-gram overlap thresholds" + ), + MimirConfig( + name="dm_mathematics", + subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"], + description="This split contains data from the Pile's DM Mathematics subset at various n-gram overlap thresholds" + ), + MimirConfig( + name="github", + subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"], + description="This split contains data from the Pile's GitHub subset at various n-gram overlap thresholds" + ), + MimirConfig( + name="hackernews", + subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"], + description="This split contains data from the Pile's HackerNews subset at various n-gram overlap thresholds" + ), + MimirConfig( + name="pile_cc", + subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"], + description="This split contains data from the Pile's Pile CC subset at various n-gram overlap thresholds" + ), + MimirConfig( + name="pubmed_central", + subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"], + description="This split contains data from the Pile's PubMed Central subset at various n-gram overlap thresholds" + ), + MimirConfig( + name="wikipedia_(en)", + subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"], + description="This split contains data from the Pile's Wikipedia subset at various n-gram overlap thresholds" ), MimirConfig( - name="the_pile_full_pile", description="This split contains data from multiple sources in the Pile", + name="full_pile", description="This split contains data from multiple sources in the Pile", + ), + MimirConfig( + name="c4", description="This split contains data the C4 dataset", + ), + MimirConfig( + name="temporal_arxiv", + subsets=["2020_08", "2021_01", "2021_06", "2022_01", "2022_06", "2023_01", "2023_06"], + description="This split contains benchmarks where non-members are selected from various months from 2020-08 and onwards", + ), + MimirConfig( + name="temporal_wiki", description="This split contains benchmarks 
         ),
     ]
 
@@ -66,9 +112,12 @@ def _info(self):
             # This is the description that will appear on the datasets page.
             description=_DESCRIPTION,
             # This defines the different columns of the dataset and their types
-            features=datasets.Features(
-                {"text": datasets.Sequence(datasets.Value("string"))}
-            ),
+            features=datasets.Features({
+                "member": datasets.Value("string"),
+                "nonmember": datasets.Value("string"),
+                "member_neighbors": datasets.Sequence(datasets.Value("string")),
+                "nonmember_neighbors": datasets.Sequence(datasets.Value("string"))
+            }),
             # If there's a common (input, target) tuple from the features,
             # specify them here. They'll be used if as_supervised=True in
             # builder.as_dataset.
@@ -76,7 +125,7 @@ def _info(self):
             # Homepage of the dataset for documentation
             homepage=_HOMEPAGE,
             # Citation for the dataset
-            # citation=_CITATION,
+            citation=_CITATION,
         )
 
     def _split_generators(self, dl_manager: DownloadManager):
@@ -85,50 +134,61 @@ def _split_generators(self, dl_manager: DownloadManager):
         NEIGHBOR_SUFFIX = "_neighbors_25_bert_in_place_swap"
         parent_dir = (
             "cache_100_200_10000_512"
-            if self.config.name == "the_pile_full_pile"
+            if self.config.name == "full_pile"
             else "cache_100_200_1000_512"
         )
 
-        file_paths = {
-            "member": os.path.join(parent_dir, "train", self.config.name + ".jsonl"),
-            "nonmember": os.path.join(parent_dir, "test", self.config.name + ".jsonl"),
-        }
-        # Load neighbor splits if they exist
-        # TODO: This is not correct (should be checking URL, not local file structure). Fix later
-        if os.path.exists(
-            os.path.join(
-                parent_dir,
-                "train_neighbors",
-                self.config.name + f"{NEIGHBOR_SUFFIX}.jsonl",
-            )
-        ):
-            # Assume if train nieghbors exist, test neighbors also exist
-            file_paths["member_neighbors"] = os.path.join(
+        if len(self.config.subsets) > 0:
+            suffixes = [f"{subset}" for subset in self.config.subsets]
+        else:
+            suffixes = ["none"]
+
+        file_paths = {}
+        for subset_split_suffix in suffixes:
+            internal_fp = {}
+
+            subset_split_suffix_use = f"_{subset_split_suffix}" if subset_split_suffix != "none" else ""
+
+            # Add standard member and non-member paths
+            internal_fp['member'] = os.path.join(parent_dir, "train", f"{self.config.name}{subset_split_suffix_use}.jsonl")
+            internal_fp['nonmember'] = os.path.join(parent_dir, "test", f"{self.config.name}{subset_split_suffix_use}.jsonl")
+
+            # Load associated neighbors
+            internal_fp['member_neighbors'] = os.path.join(
                 parent_dir,
                 "train_neighbors",
-                self.config.name + f"{NEIGHBOR_SUFFIX}.jsonl",
+                f"{self.config.name}{subset_split_suffix_use}{NEIGHBOR_SUFFIX}.jsonl",
             )
-            file_paths["nonmember_neighbors"] = os.path.join(
+            internal_fp['nonmember_neighbors'] = os.path.join(
                 parent_dir,
                 "test_neighbors",
-                self.config.name + f"{NEIGHBOR_SUFFIX}.jsonl",
+                f"{self.config.name}{subset_split_suffix_use}{NEIGHBOR_SUFFIX}.jsonl",
             )
+            file_paths[subset_split_suffix] = internal_fp
 
         # Now that we know which files to load, download them
-        download_paths = [_DOWNLOAD_URL + v for v in file_paths.values()]
-        data_dir = dl_manager.download_and_extract(download_paths)
+        data_dir = {}
+        for k, v_dict in file_paths.items():
+            download_paths = []
+            for v in v_dict.values():
+                download_paths.append(_DOWNLOAD_URL + v)
+            paths = dl_manager.download_and_extract(download_paths)
+            internal_dict = {k:v for k, v in zip(v_dict.keys(), paths)}
+            data_dir[k] = internal_dict
 
         splits = []
-        for i, k in enumerate(file_paths.keys()):
-            splits.append(SplitGenerator(name=k, gen_kwargs={"file_path": data_dir[i]}))
+        for k in suffixes:
+            splits.append(SplitGenerator(name=k, gen_kwargs={"file_path_dict": data_dir[k]}))
         return splits
 
-    def _generate_examples(self, file_path):
+    def _generate_examples(self, file_path_dict):
         """Yields examples."""
-        # Open the specified .jsonl file and read each line
-        with open(file_path, "r") as f:
-            for id, line in enumerate(f):
-                data = json.loads(line)
-                if type(data) != list:
-                    data = [data]
-                yield id, {"text": data}
+        # Open all four files in file_path_dict and yield examples (one from each file) simultaneously
+        with open(file_path_dict["member"], "r") as f_member, open(file_path_dict["nonmember"], "r") as f_nonmember, open(file_path_dict["member_neighbors"], "r") as f_member_neighbors, open(file_path_dict["nonmember_neighbors"], "r") as f_nonmember_neighbors:
+            for id, (member, nonmember, member_neighbors, nonmember_neighbors) in enumerate(zip(f_member, f_nonmember, f_member_neighbors, f_nonmember_neighbors)):
+                yield id, {
+                    "member": json.loads(member),
+                    "nonmember": json.loads(nonmember),
+                    "member_neighbors": json.loads(member_neighbors)[0],
+                    "nonmember_neighbors": json.loads(nonmember_neighbors)[0],
+                }
\ No newline at end of file
diff --git a/run.py b/run.py
index ce17285..ea5f96d 100644
--- a/run.py
+++ b/run.py
@@ -527,7 +527,7 @@ def main(config: ExperimentConfig):
             other_objs.append(data_obj_nonmem_others)
             other_nonmembers.append(data_nonmember_others)
 
-    if config.dump_cache and not config.load_from_cache:
+    if config.dump_cache and not (config.load_from_cache or config.load_from_hf):
         print("Data dumped! Please re-run with load_from_cache set to True")
         exit(0)
diff --git a/setup.py b/setup.py
index f22301e..00bf06f 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
     description="Python package for measuring memorization in LLMs",
    author="Anshuman Suri, Michael Duan, Niloofar Mireshghallah",
     author_email="as9rw@virginia.edu",
-    version="0.9",
+    version="1.0",
     url="https://github.com/iamgroot42/mimir",
     license="MIT",
     python_requires=">=3.9",
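
For reference, the reorganized HuggingFace configs can be exercised roughly as below. This is a minimal sketch following the access pattern in `python_scripts/check_hf.py`, but pointing at the `iamgroot42/mimir` Hub repository as `mimir/custom_datasets.py` does; the config and split names are taken from `BUILDER_CONFIGS` in `python_scripts/mimir.py`, and anything beyond those names is an assumption rather than part of this patch.

```python
from datasets import load_dataset

# Pile CC benchmark at the 7-gram / 0.2 overlap threshold; each record carries
# "member", "nonmember", "member_neighbors", and "nonmember_neighbors" columns.
ds = load_dataset("iamgroot42/mimir", "pile_cc", split="ngram_7_0.2")
print(ds["member"][0])                    # one member document
print(len(ds["nonmember_neighbors"][0]))  # 25 BERT in-place-swap neighbors

# Configs without n-gram subsets (e.g. "full_pile") expose a single "none" split.
full = load_dataset("iamgroot42/mimir", "full_pile", split="none")
assert len(full["member"]) == 10000
```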