Skip to content

Commit

Permalink
Merge pull request #4 from OpenDrugDiscovery/minor-fixes
Browse files Browse the repository at this point in the history
Minor fixes
  • Loading branch information
shenoynikhil authored Sep 26, 2023
2 parents 56834ee + 77f1efe commit 5ec94bd
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 14 deletions.
14 changes: 10 additions & 4 deletions src/openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import torch
from loguru import logger
from sklearn.utils import Bunch
from tqdm import tqdm

Expand Down Expand Up @@ -124,6 +125,7 @@ def collate_list(self, list_entries):

def save_preprocess(self, data_dict):
# save memmaps
logger.info("Preprocessing data and saving it to cache.")
for key in self.data_keys:
local_path = p_join(self.preprocess_path, f"{key}.mmap")
out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape)
Expand All @@ -140,6 +142,7 @@ def save_preprocess(self, data_dict):
push_remote(local_path)

def read_preprocess(self):
logger.info("Reading preprocessed data")
self.data = {}
for key in self.data_keys:
filename = p_join(self.preprocess_path, f"{key}.mmap")
Expand Down Expand Up @@ -172,14 +175,17 @@ def __len__(self):
def __getitem__(self, idx: int):
p_start, p_end = self.data["position_idx_range"][idx]
input = self.data["atomic_inputs"][p_start:p_end]
z, c, positions = input[:, 0], input[:, 1], input[:, -3:]
z, c = z.astype(np.int32), c.astype(np.int32)
energies = self.data["energies"][idx]
z, c, positions, energies = (
np.array(input[:, 0], dtype=np.int32),
np.array(input[:, 1], dtype=np.int32),
np.array(input[:, -3:], dtype=np.float32),
np.array(self.data["energies"][idx], dtype=np.float32),
)
name = self.data["name"]["uniques"][self.data["name"]["inv_indices"][idx]]
subset = self.data["subset"]["uniques"][self.data["subset"]["inv_indices"][idx]]

if "forces" in self.data:
forces = self.data["forces"][p_start:p_end]
forces = np.array(self.data["forces"][p_start:p_end], dtype=np.float32)
else:
forces = None

Expand Down
4 changes: 2 additions & 2 deletions src/openqdc/datasets/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from openqdc.datasets.base import BaseDataset
from openqdc.utils import load_json, load_pkl
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import get_atomic_numuber_and_charge
from openqdc.utils.molecule import get_atomic_number_and_charge


def read_mol(mol_id, mol_dict, base_path, partition):
Expand Down Expand Up @@ -34,7 +34,7 @@ def read_mol(mol_id, mol_dict, base_path, partition):
try:
d = load_pkl(p_join(base_path, mol_dict["pickle_path"]), False)
confs = d["conformers"]
x = get_atomic_numuber_and_charge(confs[0]["rd_mol"])
x = get_atomic_number_and_charge(confs[0]["rd_mol"])
positions = np.array([cf["rd_mol"].GetConformer().GetPositions() for cf in confs])
n_confs = positions.shape[0]

Expand Down
4 changes: 2 additions & 2 deletions src/openqdc/datasets/molecule3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import get_atomic_numuber_and_charge
from openqdc.utils.molecule import get_atomic_number_and_charge


def read_mol(mol, energy):
smiles = dm.to_smiles(mol, explicit_hs=False)
# subset = dm.to_smiles(dm.to_scaffold_murcko(mol, make_generic=True), explicit_hs=False)
x = get_atomic_numuber_and_charge(mol)
x = get_atomic_number_and_charge(mol)
positions = mol.GetConformer().GetPositions() * BOHR2ANG

res = dict(
Expand Down
4 changes: 2 additions & 2 deletions src/openqdc/datasets/qmugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import get_atomic_numuber_and_charge
from openqdc.utils.molecule import get_atomic_number_and_charge


def read_mol(mol_dir):
Expand All @@ -19,7 +19,7 @@ def read_mol(mol_dir):
return None

smiles = dm.to_smiles(mols[0], explicit_hs=False)
x = get_atomic_numuber_and_charge(mols[0])[None, ...].repeat(n_confs, axis=0)
x = get_atomic_number_and_charge(mols[0])[None, ...].repeat(n_confs, axis=0)
positions = np.array([mol.GetConformer().GetPositions() for mol in mols])
props = [mol.GetPropsAsDict() for mol in mols]
targets = np.array([[p[el] for el in QMugs.energy_target_names] for p in props])
Expand Down
4 changes: 2 additions & 2 deletions src/openqdc/datasets/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from openqdc.datasets.base import BaseDataset
from openqdc.utils import load_hdf5_file
from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import get_atomic_numuber_and_charge
from openqdc.utils.molecule import get_atomic_number_and_charge


def read_record(r):
smiles = r["smiles"].asstr()[0]
subset = r["subset"][0].decode("utf-8")
n_confs = r["conformations"].shape[0]
x = get_atomic_numuber_and_charge(dm.to_mol(smiles, add_hs=True))
x = get_atomic_number_and_charge(dm.to_mol(smiles, add_hs=True))
positions = r["conformations"][:] * BOHR2ANG

res = dict(
Expand Down
2 changes: 1 addition & 1 deletion src/openqdc/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def push_remote(local_path, overwrite=True):
return remote_path


def pull_locally(local_path, overwrite=True):
def pull_locally(local_path, overwrite=False):
remote_path = local_path.replace(get_local_cache(), get_remote_cache())
os.makedirs(os.path.dirname(local_path), exist_ok=True)
if not os.path.exists(local_path) or overwrite:
Expand Down
2 changes: 1 addition & 1 deletion src/openqdc/utils/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ def get_atomic_charge(mol: Chem.Mol):
return np.array([atom.GetFormalCharge() for atom in mol.GetAtoms()])


def get_atomic_numuber_and_charge(mol: Chem.Mol):
def get_atomic_number_and_charge(mol: Chem.Mol):
"""Returns atoms number and charge for rdkit molecule"""
return np.array([[atom.GetAtomicNum(), atom.GetFormalCharge()] for atom in mol.GetAtoms()])

0 comments on commit 5ec94bd

Please sign in to comment.