diff --git a/src/openqdc/datasets/base.py b/src/openqdc/datasets/base.py index 254ff0d..1c1a2c6 100644 --- a/src/openqdc/datasets/base.py +++ b/src/openqdc/datasets/base.py @@ -3,6 +3,7 @@ import numpy as np import torch +from loguru import logger from sklearn.utils import Bunch from tqdm import tqdm @@ -124,6 +125,7 @@ def collate_list(self, list_entries): def save_preprocess(self, data_dict): # save memmaps + logger.info("Preprocessing data and saving it to cache.") for key in self.data_keys: local_path = p_join(self.preprocess_path, f"{key}.mmap") out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape) @@ -140,6 +142,7 @@ def save_preprocess(self, data_dict): push_remote(local_path) def read_preprocess(self): + logger.info("Reading preprocessed data") self.data = {} for key in self.data_keys: filename = p_join(self.preprocess_path, f"{key}.mmap") @@ -172,14 +175,17 @@ def __len__(self): def __getitem__(self, idx: int): p_start, p_end = self.data["position_idx_range"][idx] input = self.data["atomic_inputs"][p_start:p_end] - z, c, positions = input[:, 0], input[:, 1], input[:, -3:] - z, c = z.astype(np.int32), c.astype(np.int32) - energies = self.data["energies"][idx] + z, c, positions, energies = ( + np.array(input[:, 0], dtype=np.int32), + np.array(input[:, 1], dtype=np.int32), + np.array(input[:, -3:], dtype=np.float32), + np.array(self.data["energies"][idx], dtype=np.float32), + ) name = self.data["name"]["uniques"][self.data["name"]["inv_indices"][idx]] subset = self.data["subset"]["uniques"][self.data["subset"]["inv_indices"][idx]] if "forces" in self.data: - forces = self.data["forces"][p_start:p_end] + forces = np.array(self.data["forces"][p_start:p_end], dtype=np.float32) else: forces = None diff --git a/src/openqdc/datasets/geom.py b/src/openqdc/datasets/geom.py index c2f895a..6af826e 100644 --- a/src/openqdc/datasets/geom.py +++ b/src/openqdc/datasets/geom.py @@ -6,7 +6,7 @@ from openqdc.datasets.base import BaseDataset from openqdc.utils import load_json, load_pkl from openqdc.utils.constants import MAX_ATOMIC_NUMBER -from openqdc.utils.molecule import get_atomic_numuber_and_charge +from openqdc.utils.molecule import get_atomic_number_and_charge def read_mol(mol_id, mol_dict, base_path, partition): @@ -34,7 +34,7 @@ def read_mol(mol_id, mol_dict, base_path, partition): try: d = load_pkl(p_join(base_path, mol_dict["pickle_path"]), False) confs = d["conformers"] - x = get_atomic_numuber_and_charge(confs[0]["rd_mol"]) + x = get_atomic_number_and_charge(confs[0]["rd_mol"]) positions = np.array([cf["rd_mol"].GetConformer().GetPositions() for cf in confs]) n_confs = positions.shape[0] diff --git a/src/openqdc/datasets/molecule3d.py b/src/openqdc/datasets/molecule3d.py index ac4f348..0d59400 100644 --- a/src/openqdc/datasets/molecule3d.py +++ b/src/openqdc/datasets/molecule3d.py @@ -9,13 +9,13 @@ from openqdc.datasets.base import BaseDataset from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER -from openqdc.utils.molecule import get_atomic_numuber_and_charge +from openqdc.utils.molecule import get_atomic_number_and_charge def read_mol(mol, energy): smiles = dm.to_smiles(mol, explicit_hs=False) # subset = dm.to_smiles(dm.to_scaffold_murcko(mol, make_generic=True), explicit_hs=False) - x = get_atomic_numuber_and_charge(mol) + x = get_atomic_number_and_charge(mol) positions = mol.GetConformer().GetPositions() * BOHR2ANG res = dict( diff --git a/src/openqdc/datasets/qmugs.py b/src/openqdc/datasets/qmugs.py index 6868f38..b528f42 100644 --- a/src/openqdc/datasets/qmugs.py +++ b/src/openqdc/datasets/qmugs.py @@ -7,7 +7,7 @@ from openqdc.datasets.base import BaseDataset from openqdc.utils.constants import MAX_ATOMIC_NUMBER -from openqdc.utils.molecule import get_atomic_numuber_and_charge +from openqdc.utils.molecule import get_atomic_number_and_charge def read_mol(mol_dir): @@ -19,7 +19,7 @@ def read_mol(mol_dir): return None smiles = dm.to_smiles(mols[0], explicit_hs=False) - x = get_atomic_numuber_and_charge(mols[0])[None, ...].repeat(n_confs, axis=0) + x = get_atomic_number_and_charge(mols[0])[None, ...].repeat(n_confs, axis=0) positions = np.array([mol.GetConformer().GetPositions() for mol in mols]) props = [mol.GetPropsAsDict() for mol in mols] targets = np.array([[p[el] for el in QMugs.energy_target_names] for p in props]) diff --git a/src/openqdc/datasets/spice.py b/src/openqdc/datasets/spice.py index 0aec9b2..88af6dc 100644 --- a/src/openqdc/datasets/spice.py +++ b/src/openqdc/datasets/spice.py @@ -7,14 +7,14 @@ from openqdc.datasets.base import BaseDataset from openqdc.utils import load_hdf5_file from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER -from openqdc.utils.molecule import get_atomic_numuber_and_charge +from openqdc.utils.molecule import get_atomic_number_and_charge def read_record(r): smiles = r["smiles"].asstr()[0] subset = r["subset"][0].decode("utf-8") n_confs = r["conformations"].shape[0] - x = get_atomic_numuber_and_charge(dm.to_mol(smiles, add_hs=True)) + x = get_atomic_number_and_charge(dm.to_mol(smiles, add_hs=True)) positions = r["conformations"][:] * BOHR2ANG res = dict( diff --git a/src/openqdc/utils/io.py b/src/openqdc/utils/io.py index 0391add..0a5f7c5 100644 --- a/src/openqdc/utils/io.py +++ b/src/openqdc/utils/io.py @@ -34,7 +34,7 @@ def push_remote(local_path, overwrite=True): return remote_path -def pull_locally(local_path, overwrite=True): +def pull_locally(local_path, overwrite=False): remote_path = local_path.replace(get_local_cache(), get_remote_cache()) os.makedirs(os.path.dirname(local_path), exist_ok=True) if not os.path.exists(local_path) or overwrite: diff --git a/src/openqdc/utils/molecule.py b/src/openqdc/utils/molecule.py index e8c1c9c..cd2290f 100644 --- a/src/openqdc/utils/molecule.py +++ b/src/openqdc/utils/molecule.py @@ -14,6 +14,6 @@ def get_atomic_charge(mol: Chem.Mol): return np.array([atom.GetFormalCharge() for atom in mol.GetAtoms()]) -def get_atomic_numuber_and_charge(mol: Chem.Mol): +def get_atomic_number_and_charge(mol: Chem.Mol): """Returns atoms number and charge for rdkit molecule""" return np.array([[atom.GetAtomicNum(), atom.GetFormalCharge()] for atom in mol.GetAtoms()])