diff --git a/src/openqdc/datasets/base.py b/src/openqdc/datasets/base.py index 1c1a2c6..392144d 100644 --- a/src/openqdc/datasets/base.py +++ b/src/openqdc/datasets/base.py @@ -159,10 +159,12 @@ def read_preprocess(self): for key in ["name", "subset"]: filename = p_join(self.preprocess_path, f"{key}.npz") pull_locally(filename) - # with open(filename, "rb") as f: - self.data[key] = np.load(open(filename, "rb")) - for k in self.data[key]: - print(f"Loaded {key}_{k} with shape {self.data[key][k].shape}, dtype {self.data[key][k].dtype}") + self.data[key] = dict() + with open(filename, "rb") as f: + tmp = np.load(f) + for k in tmp: + self.data[key][k] = tmp[k] + print(f"Loaded {key}_{k} with shape {self.data[key][k].shape}, dtype {self.data[key][k].dtype}") def is_preprocessed(self): predicats = [copy_exists(p_join(self.preprocess_path, f"{key}.mmap")) for key in self.data_keys]