diff --git a/mimir/config.py b/mimir/config.py index 13d5d57..d6584c5 100644 --- a/mimir/config.py +++ b/mimir/config.py @@ -146,7 +146,7 @@ class ExperimentConfig(Serializable): "Dump data to cache? Exits program after dumping" load_from_cache: Optional[bool] = False """Load data from cache?""" - load_from_hf: Optional[bool] = False + load_from_hf: Optional[bool] = True """Load data from HuggingFace?""" blackbox_attacks: Optional[List[str]] = field( default_factory=lambda: None diff --git a/mimir/custom_datasets.py b/mimir/custom_datasets.py index 29ce84e..dd84219 100644 --- a/mimir/custom_datasets.py +++ b/mimir/custom_datasets.py @@ -12,6 +12,20 @@ DATASETS = ['writing', 'english', 'german', 'pubmed'] +SOURCES_UPLOADED = [ + "arxiv", + "dm_mathematics", + "github", + "hackernews", + "pile_cc", + "pubmed_central", + "wikipedia_(en)", + "full_pile", + "c4", + "temporal_arxiv", + "temporal_wiki" +] + def load_pubmed(cache_dir): data = datasets.load_dataset('pubmed_qa', 'pqa_labeled', split='train', cache_dir=cache_dir) @@ -41,20 +55,6 @@ def load_cached(cache_dir, if not filename.startswith("the_pile"): raise ValueError(f"HuggingFace data only available for The Pile.") - SOURCES_UPLOADED = [ - "arxiv", - "dm_mathematics", - "github", - "hackernews", - "pile_cc", - "pubmed_central", - "wikipedia_(en)", - "full_pile", - "c4", - "temporal_arxiv", - "temporal_wiki" - ] - for source in SOURCES_UPLOADED: # Got a match if source in filename and filename.startswith(f"the_pile_{source}"): diff --git a/python_scripts/mimir.py b/python_scripts/mimir.py index 8547afe..cf1b0a9 100644 --- a/python_scripts/mimir.py +++ b/python_scripts/mimir.py @@ -189,6 +189,6 @@ def _generate_examples(self, file_path_dict): yield id, { "member": json.loads(member), "nonmember": json.loads(nonmember), - "member_neighbors": json.loads(member_neighbors)[0], - "nonmember_neighbors": json.loads(nonmember_neighbors)[0], + "member_neighbors": json.loads(member_neighbors), + "nonmember_neighbors": json.loads(nonmember_neighbors), } \ No newline at end of file diff --git a/run.py b/run.py index ea5f96d..ec9f388 100644 --- a/run.py +++ b/run.py @@ -190,7 +190,7 @@ def get_mia_scores( else None ), loss=loss, - batch_siz=4, + batch_size=4, substr_neighbors=substr_neighbors, )