Tested HF loading for NE attack
iamgroot42 committed Feb 13, 2024
1 parent 2773275 commit 29352bc
Showing 8 changed files with 179 additions and 69 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -58,3 +58,16 @@ To add an attack, create a file for your attack (e.g. `attacks/my_attack.py`) an
Then, add a name for your attack to the dictionary in `attacks/blackbox_attack.py`.

If you would like to submit your attack to the repository, please open a pull request describing your attack and the paper it is based on.
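
For illustration, here is a minimal sketch of what such a file might contain. The class, its constructor, and the `attack` method shown here are assumptions for the sketch, not MIMIR's actual attack interface; consult the existing implementations under `attacks/` for the real base class and signatures.

```python
# attacks/my_attack.py -- illustrative sketch only; in the repository this class
# would subclass the shared attack base class instead of standing alone.
import numpy as np


class MyAttack:
    """Toy membership score: mean token log-likelihood under the target model."""

    def __init__(self, config=None, target_model=None):
        self.config = config
        self.target_model = target_model

    def attack(self, document: str, log_probs: list) -> float:
        # Higher (less negative) average log-likelihood => more member-like.
        return float(np.mean(log_probs))


if __name__ == "__main__":
    # Smoke test with fabricated token log-probabilities.
    print(MyAttack().attack("example document", [-2.1, -0.4, -1.3]))  # -1.2666...
```

The new attack is then registered by adding a key for it to the dictionary in `attacks/blackbox_attack.py`, as described above.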

## Citation

If you use MIMIR in your research, please cite our paper:

```bibtex
@article{duan2024membership,
title={Do Membership Inference Attacks Work on Large Language Models?},
author={Michael Duan and Anshuman Suri and Niloofar Mireshghallah and Sewon Min and Weijia Shi and Luke Zettlemoyer and Yulia Tsvetkov and Yejin Choi and David Evans and Hannaneh Hajishirzi},
year={2024},
journal={arXiv:2402.07841},
}
```
7 changes: 5 additions & 2 deletions mimir/config.py
@@ -202,14 +202,17 @@ class ExperimentConfig(Serializable):
"""OpenAI config"""

def __post_init__(self):
if self.dump_cache and self.load_from_cache:
if self.dump_cache and (self.load_from_cache or self.load_from_hf):
raise ValueError("Cannot dump and load cache at the same time")

if self.neighborhood_config:
if (
self.neighborhood_config.dump_cache
or self.neighborhood_config.load_from_cache
) and not (self.load_from_cache or self.dump_cache):
) and not (self.load_from_cache or self.dump_cache or self.load_from_hf):
raise ValueError(
"Using dump/load for neighborhood cache without dumping/loading main cache does not make sense"
)

if self.neighborhood_config.dump_cache and (self.neighborhood_config.load_from_cache or self.load_from_hf):
raise ValueError("Cannot dump and load neighborhood cache at the same time")
57 changes: 41 additions & 16 deletions mimir/custom_datasets.py
@@ -38,10 +38,47 @@ def load_cached(cache_dir,
print("Loading from HuggingFace!")
data_split = data_split.replace("train", "member")
data_split = data_split.replace("test", "nonmember")
ds = datasets.load_dataset("iamgroot42/mimir", name=filename, split=data_split)
data = collect_hf_data(ds)
if len(data) != n_samples:
raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
if not filename.startswith("the_pile"):
raise ValueError(f"HuggingFace data only available for The Pile.")

SOURCES_UPLOADED = [
"arxiv",
"dm_mathematics",
"github",
"hackernews",
"pile_cc",
"pubmed_central",
"wikipedia_(en)",
"full_pile",
"c4",
"temporal_arxiv",
"temporal_wiki"
]

for source in SOURCES_UPLOADED:
# Got a match
if source in filename and filename.startswith(f"the_pile_{source}"):
split = filename.split(f"the_pile_{source}")[1]
if split == "":
# In the HF upload, the absence of a split suffix is recorded as the "none" split
split = "none"
else:
# remove the first underscore
split = split[1:]
# remove '<' , '>'
split = split.replace("<", "").replace(">", "")
# Remove "_truncated" from the end, if present
split = split.rsplit("_truncated", 1)[0]

# Load corresponding dataset
ds = datasets.load_dataset("iamgroot42/mimir", name=source, split=split)
data = ds[data_split]
# Check if the number of samples is correct
if len(data) != n_samples:
raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
return data
# If got here, matching source was not found
raise ValueError(f"Requested source {filename} not found in HuggingFace data.")
else:
file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", data_split, filename + ".jsonl")
if not os.path.exists(file_path):
@@ -50,18 +87,6 @@ def load_cached(cache_dir,
return data


def collect_hf_data(ds):
"""
Helper function to collect all data from a given HuggingFace dataset split.
"""
records = [x["text"] for x in ds]
# Standard DS
if len(records[0]) == 1:
records = [x[0] for x in records]
# Neighbor data
return records


def load_data(file_path):
"""
Load data from a given filepath (.jsonl)
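To make the filename-to-dataset mapping in the new `load_cached` branch concrete, here is a standalone sketch of the same resolution logic (the helper name is illustrative; the parsing mirrors the code above). A cached name such as `the_pile_arxiv_ngram_13_0.8_truncated` resolves to the HF config `arxiv` with split `ngram_13_0.8`, while a bare `the_pile_arxiv` resolves to the `none` split.

```python
def resolve_hf_name(filename: str, source: str) -> tuple:
    """Illustrative helper mirroring load_cached: map a cached filename
    onto a (config name, split) pair for iamgroot42/mimir."""
    split = filename.split(f"the_pile_{source}")[1]
    if split == "":
        split = "none"  # no suffix is stored as the "none" split on HF
    else:
        split = split[1:]  # drop the leading underscore
        split = split.replace("<", "").replace(">", "")  # strip '<' and '>'
        split = split.rsplit("_truncated", 1)[0]  # drop a trailing "_truncated"
    return source, split


print(resolve_hf_name("the_pile_arxiv_ngram_13_0.8_truncated", "arxiv"))
# ('arxiv', 'ngram_13_0.8')
print(resolve_hf_name("the_pile_wikipedia_(en)", "wikipedia_(en)"))
# ('wikipedia_(en)', 'none')
```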
2 changes: 1 addition & 1 deletion mimir/data_utils.py
@@ -92,7 +92,7 @@ def load(self, train: bool, mask_tokenizer=None, specific_source: str = None):
assert not self.config.full_doc
data = np.load(self.presampled)
return data
elif self.config.load_from_cache:
elif (self.config.load_from_cache or self.config.load_from_hf):
# Load from cache, if requested
filename = self._get_name_to_save()
data = custom_datasets.load_cached(
17 changes: 13 additions & 4 deletions python_scripts/check_hf.py
@@ -1,6 +1,15 @@
from datasets import load_dataset

ds = load_dataset("mimir.py", "the_pile_full_pile")
print(ds)
ds = load_dataset("mimir.py", "the_pile_arxiv")
print(ds['member'][0]['text'])
ds = load_dataset("mimir.py", "pile_cc", split="ngram_7_0.2")
print(ds['member'][0])
ds = load_dataset("mimir.py", "full_pile", split="none")
print(len(ds['member']))
assert len(ds['member']) == 10000
print(ds["nonmember_neighbors"][0])
print(len(ds["member_neighbors"]))
print(ds['member_neighbors'][0][12])
ds = load_dataset("mimir.py", "arxiv", split="ngram_13_0.8")
print(ds["nonmember_neighbors"][1][9])

assert len(ds['member']) == 1000
assert len(ds["nonmember_neighbors"][0]) == 25
148 changes: 104 additions & 44 deletions python_scripts/mimir.py
@@ -14,6 +14,8 @@

import datasets

from typing import List


_HOMEPAGE = "http://github.com/iamgroot42/mimir"

@@ -23,11 +25,11 @@
"""

_CITATION = """\
@article{duan2024do,
title={Do Membership Inference Attacks Work on Large Language Models?},
author={Duan*, Michael and \textbf{A. Suri*} and Mireshghallah, Niloofar and Min, Sewon and Shi, Weijia and Zettlemoyer, Luke and Tsvetkov, Yulia and Choi, Yejin and Evans, David and Hajishirzi, Hannaneh},
journal={arXiv preprint arXiv:???},
year={2024}
@article{duan2024membership,
title={Do Membership Inference Attacks Work on Large Language Models?},
author={Michael Duan and Anshuman Suri and Niloofar Mireshghallah and Sewon Min and Weijia Shi and Luke Zettlemoyer and Yulia Tsvetkov and Yejin Choi and David Evans and Hannaneh Hajishirzi},
year={2024},
journal={arXiv:2402.07841},
}
"""

@@ -37,27 +39,71 @@
class MimirConfig(BuilderConfig):
"""BuilderConfig for Mimir dataset."""

def __init__(self, **kwargs):
def __init__(self, *args, subsets: List[str]=[], **kwargs):
"""Constructs a MimirConfig.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(MimirConfig, self).__init__(**kwargs)
self.subsets = subsets


class MimirDataset(GeneratorBasedBuilder):
# Assuming 'VERSION' is defined
VERSION = datasets.Version("1.0.0")
VERSION = datasets.Version("1.3.0")

# Define the builder configs
BUILDER_CONFIG_CLASS = MimirConfig
BUILDER_CONFIGS = [
MimirConfig(
name="the_pile_arxiv", description="This split contains data from Arxiv"
name="arxiv",
subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"],
description="This split contains data from the Pile's Arxiv subset at various n-gram overlap thresholds"
),
MimirConfig(
name="dm_mathematics",
subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"],
description="This split contains data from the Pile's DM Mathematics subset at various n-gram overlap thresholds"
),
MimirConfig(
name="github",
subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"],
description="This split contains data from the Pile's GitHub subset at various n-gram overlap thresholds"
),
MimirConfig(
name="hackernews",
subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"],
description="This split contains data from the Pile's HackerNews subset at various n-gram overlap thresholds"
),
MimirConfig(
name="pile_cc",
subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"],
description="This split contains data from the Pile's Pile CC subset at various n-gram overlap thresholds"
),
MimirConfig(
name="pubmed_central",
subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"],
description="This split contains data from the Pile's PubMed Central subset at various n-gram overlap thresholds"
),
MimirConfig(
name="wikipedia_(en)",
subsets=["ngram_7_0.2", "ngram_13_0.2", "ngram_13_0.8"],
description="This split contains data from the Pile's Wikipedia subset at various n-gram overlap thresholds"
),
MimirConfig(
name="the_pile_full_pile", description="This split contains data from multiple sources in the Pile",
name="full_pile", description="This split contains data from multiple sources in the Pile",
),
MimirConfig(
name="c4", description="This split contains data the C4 dataset",
),
MimirConfig(
name="temporal_arxiv",
subsets=["2020_08", "2021_01", "2021_06", "2022_01", "2022_06", "2023_01", "2023_06"],
description="This split contains benchmarks where non-members are selected from various months from 2020-08 and onwards",
),
MimirConfig(
name="temporal_wiki", description="This split contains benchmarks where non-members are selected from 2023-08 and onwards",
),
]

@@ -66,17 +112,20 @@ def _info(self):
# This is the description that will appear on the datasets page.
description=_DESCRIPTION,
# This defines the different columns of the dataset and their types
features=datasets.Features(
{"text": datasets.Sequence(datasets.Value("string"))}
),
features=datasets.Features({
"member": datasets.Value("string"),
"nonmember": datasets.Value("string"),
"member_neighbors": datasets.Sequence(datasets.Value("string")),
"nonmember_neighbors": datasets.Sequence(datasets.Value("string"))
}),
# If there's a common (input, target) tuple from the features,
# specify them here. They'll be used if as_supervised=True in
# builder.as_dataset.
supervised_keys=None,
# Homepage of the dataset for documentation
homepage=_HOMEPAGE,
# Citation for the dataset
# citation=_CITATION,
citation=_CITATION,
)

def _split_generators(self, dl_manager: DownloadManager):
@@ -85,50 +134,61 @@ def _split_generators(self, dl_manager: DownloadManager):
NEIGHBOR_SUFFIX = "_neighbors_25_bert_in_place_swap"
parent_dir = (
"cache_100_200_10000_512"
if self.config.name == "the_pile_full_pile"
if self.config.name == "full_pile"
else "cache_100_200_1000_512"
)

file_paths = {
"member": os.path.join(parent_dir, "train", self.config.name + ".jsonl"),
"nonmember": os.path.join(parent_dir, "test", self.config.name + ".jsonl"),
}
# Load neighbor splits if they exist
# TODO: This is not correct (should be checking URL, not local file structure). Fix later
if os.path.exists(
os.path.join(
parent_dir,
"train_neighbors",
self.config.name + f"{NEIGHBOR_SUFFIX}.jsonl",
)
):
# Assume if train neighbors exist, test neighbors also exist
file_paths["member_neighbors"] = os.path.join(
if len(self.config.subsets) > 0:
suffixes = [f"{subset}" for subset in self.config.subsets]
else:
suffixes = ["none"]

file_paths = {}
for subset_split_suffix in suffixes:
internal_fp = {}

subset_split_suffix_use = f"_{subset_split_suffix}" if subset_split_suffix != "none" else ""

# Add standard member and non-member paths
internal_fp['member'] = os.path.join(parent_dir, "train", f"{self.config.name}{subset_split_suffix_use}.jsonl")
internal_fp['nonmember'] = os.path.join(parent_dir, "test", f"{self.config.name}{subset_split_suffix_use}.jsonl")

# Load associated neighbors
internal_fp['member_neighbors'] = os.path.join(
parent_dir,
"train_neighbors",
self.config.name + f"{NEIGHBOR_SUFFIX}.jsonl",
f"{self.config.name}{subset_split_suffix_use}{NEIGHBOR_SUFFIX}.jsonl",
)
file_paths["nonmember_neighbors"] = os.path.join(
internal_fp['nonmember_neighbors'] = os.path.join(
parent_dir,
"test_neighbors",
self.config.name + f"{NEIGHBOR_SUFFIX}.jsonl",
f"{self.config.name}{subset_split_suffix_use}{NEIGHBOR_SUFFIX}.jsonl",
)
file_paths[subset_split_suffix] = internal_fp

# Now that we know which files to load, download them
download_paths = [_DOWNLOAD_URL + v for v in file_paths.values()]
data_dir = dl_manager.download_and_extract(download_paths)
data_dir = {}
for k, v_dict in file_paths.items():
download_paths = []
for v in v_dict.values():
download_paths.append(_DOWNLOAD_URL + v)
paths = dl_manager.download_and_extract(download_paths)
internal_dict = {k:v for k, v in zip(v_dict.keys(), paths)}
data_dir[k] = internal_dict

splits = []
for i, k in enumerate(file_paths.keys()):
splits.append(SplitGenerator(name=k, gen_kwargs={"file_path": data_dir[i]}))
for k in suffixes:
splits.append(SplitGenerator(name=k, gen_kwargs={"file_path_dict": data_dir[k]}))
return splits

def _generate_examples(self, file_path):
def _generate_examples(self, file_path_dict):
"""Yields examples."""
# Open the specified .jsonl file and read each line
with open(file_path, "r") as f:
for id, line in enumerate(f):
data = json.loads(line)
if type(data) != list:
data = [data]
yield id, {"text": data}
# Open all four files in file_path_dict and yield examples (one from each file) simultaneously
with open(file_path_dict["member"], "r") as f_member, open(file_path_dict["nonmember"], "r") as f_nonmember, open(file_path_dict["member_neighbors"], "r") as f_member_neighbors, open(file_path_dict["nonmember_neighbors"], "r") as f_nonmember_neighbors:
for id, (member, nonmember, member_neighbors, nonmember_neighbors) in enumerate(zip(f_member, f_nonmember, f_member_neighbors, f_nonmember_neighbors)):
yield id, {
"member": json.loads(member),
"nonmember": json.loads(nonmember),
"member_neighbors": json.loads(member_neighbors)[0],
"nonmember_neighbors": json.loads(nonmember_neighbors)[0],
}
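Once published, each config and split of the rewritten loader yields row-aligned `member`, `nonmember`, `member_neighbors`, and `nonmember_neighbors` columns. A brief consumption sketch follows; the scoring function is a hypothetical stand-in for a real attack, while the dataset path, config, and split names come from the loader and `check_hf.py` above.

```python
from datasets import load_dataset


def score(text: str) -> float:
    """Hypothetical membership score; a real attack would query the target model."""
    return -float(len(text))


# Config "arxiv", split "ngram_13_0.8", as exercised in check_hf.py above.
ds = load_dataset("iamgroot42/mimir", "arxiv", split="ngram_13_0.8")
member_scores = [score(x) for x in ds["member"]]
nonmember_scores = [score(x) for x in ds["nonmember"]]
print(len(member_scores))  # check_hf.py asserts 1000 members for this split
```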
2 changes: 1 addition & 1 deletion run.py
Expand Up @@ -527,7 +527,7 @@ def main(config: ExperimentConfig):
other_objs.append(data_obj_nonmem_others)
other_nonmembers.append(data_nonmember_others)

if config.dump_cache and not config.load_from_cache:
if config.dump_cache and not (config.load_from_cache or config.load_from_hf):
print("Data dumped! Please re-run with load_from_cache set to True")
exit(0)

2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
description="Python package for measuring memorization in LLMs",
author="Anshuman Suri, Michael Duan, Niloofar Mireshghallah",
author_email="[email protected]",
version="0.9",
version="1.0",
url="https://github.com/iamgroot42/mimir",
license="MIT",
python_requires=">=3.9",
