From b136dfcc95fe6a5b4c04242677442f4f4fbea2ee Mon Sep 17 00:00:00 2001
From: Anshuman Suri
Date: Thu, 8 Feb 2024 17:07:40 -0500
Subject: [PATCH] Load data from HF

---
 mimir/custom_datasets.py                   | 30 ++++++++++++++++++----
 mimir/data_utils.py                        |  2 ++
 python_scripts/{test_hf.py => check_hf.py} |  0
 3 files changed, 27 insertions(+), 5 deletions(-)
 rename python_scripts/{test_hf.py => check_hf.py} (100%)

diff --git a/mimir/custom_datasets.py b/mimir/custom_datasets.py
index 065e730..4a45355 100644
--- a/mimir/custom_datasets.py
+++ b/mimir/custom_datasets.py
@@ -22,18 +22,38 @@ def load_pubmed(cache_dir):
     return data
 
 
-def load_cached(cache_dir, path: str, filename: str, min_length: int, max_length: int, n_samples: int, max_tokens: int):
+def load_cached(cache_dir, data_split: str, filename: str, min_length: int,
+                max_length: int, n_samples: int, max_tokens: int,
+                load_from_hf: bool = False):
     """"
         Read from cache if available. Used for certain pile sources and xsum
         to ensure fairness in comparison across attacks.runs.
     """
-    file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", path, filename + ".jsonl")
-    if not os.path.exists(file_path):
-        raise ValueError(f"Requested cache file {file_path} does not exist")
-    data = load_data(file_path)
+    if load_from_hf:
+        print("Loading from HuggingFace!")
+        data_split = data_split.replace("train", "member")
+        data_split = data_split.replace("test", "nonmember")
+        ds = datasets.load_dataset("iamgroot42/mimir", name=filename, split=data_split)
+        data = collect_hf_data(ds)
+        if len(data) != n_samples:
+            raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
+    else:
+        file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", data_split, filename + ".jsonl")
+        if not os.path.exists(file_path):
+            raise ValueError(f"Requested cache file {file_path} does not exist")
+        data = load_data(file_path)
     return data
 
 
+def collect_hf_data(ds):
+    records = [x["text"] for x in ds]
+    # Standard DS
+    if len(records[0]) == 1:
+        records = [x[0] for x in records]
+    # Neighbor data
+    return records
+
+
 def load_data(file_path):
     with open(file_path, 'r') as f:
         data = [json.loads(line) for line in f.readlines()]
diff --git a/mimir/data_utils.py b/mimir/data_utils.py
index 1e61d9e..55d3183 100644
--- a/mimir/data_utils.py
+++ b/mimir/data_utils.py
@@ -51,6 +51,7 @@ def load_neighbors(
                 max_length=self.config.max_words,
                 n_samples=self.config.n_samples,
                 max_tokens=self.config.max_tokens,
+                load_from_hf=self.config.load_from_hf
             )
             return data
 
@@ -102,6 +103,7 @@ def load(self, train: bool, mask_tokenizer=None, specific_source: str = None):
                 max_length=self.config.max_words,
                 n_samples=self.config.n_samples,
                 max_tokens=self.config.max_tokens,
+                load_from_hf=self.config.load_from_hf
             )
             return data
         else:
diff --git a/python_scripts/test_hf.py b/python_scripts/check_hf.py
similarity index 100%
rename from python_scripts/test_hf.py
rename to python_scripts/check_hf.py
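
Note (not part of the commit): below is a minimal usage sketch of the new HuggingFace code path this patch adds. The config name "arxiv" and the sample counts are illustrative assumptions, not values prescribed by the patch; the split-name rewriting ("train" -> "member", "test" -> "nonmember") is taken directly from load_cached above.

    import mimir.custom_datasets as custom_datasets

    # Pull 1000 member samples straight from HuggingFace instead of the
    # local JSONL cache; the cache-path arguments are still accepted but
    # unused on this path.
    data = custom_datasets.load_cached(
        cache_dir="cache",      # ignored when load_from_hf=True
        data_split="train",     # rewritten to "member" inside load_cached
        filename="arxiv",       # hypothetical HF config name
        min_length=100,
        max_length=200,
        n_samples=1000,         # must match what the HF split provides
        max_tokens=512,
        load_from_hf=True,
    )
    print(f"Loaded {len(data)} member samples")

With load_from_hf=False, the same call instead reads cache/cache_100_200_1000_512/train/arxiv.jsonl from disk, per the else branch of the patched function.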