Skip to content

Commit

Permalink
Merge pull request #100 from OpenDrugDiscovery/units_improvements
Browse files Browse the repository at this point in the history
Units improvements + qm1b
  • Loading branch information
FNTwin authored Jun 25, 2024
2 parents 1885a0b + d6a0d1e commit e1190e3
Show file tree
Hide file tree
Showing 11 changed files with 837 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: ["ubuntu-latest"]

runs-on: ${{ matrix.os }}
Expand Down
5 changes: 5 additions & 0 deletions openqdc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def get_project_root():
"SN2RXN": "openqdc.datasets.potential.sn2_rxn",
"QM7X": "openqdc.datasets.potential.qm7x",
"QM7X_V2": "openqdc.datasets.potential.qm7x",
"QM1B": "openqdc.datasets.potential.qm1b",
"QM1B_SMALL": "openqdc.datasets.potential.qm1b",
"NablaDFT": "openqdc.datasets.potential.nabladft",
"SolvatedPeptides": "openqdc.datasets.potential.solvated_peptides",
"WaterClusters": "openqdc.datasets.potential.waterclusters3_30",
Expand Down Expand Up @@ -101,6 +103,8 @@ def __dir__():
from .datasets.interaction.metcalf import Metcalf
from .datasets.interaction.splinter import Splinter
from .datasets.interaction.x40 import X40

# POTENTIAL
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2X
from .datasets.potential.comp6 import COMP6
from .datasets.potential.dummy import Dummy
Expand All @@ -113,6 +117,7 @@ def __dir__():
from .datasets.potential.nabladft import NablaDFT
from .datasets.potential.orbnet_denali import OrbnetDenali
from .datasets.potential.pcqm import PCQM_B3LYP, PCQM_PM6
from .datasets.potential.qm1b import QM1B, QM1B_SMALL
from .datasets.potential.qm7x import QM7X, QM7X_V2
from .datasets.potential.qmugs import QMugs, QMugs_V2
from .datasets.potential.revmd17 import RevMD17
Expand Down
42 changes: 26 additions & 16 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
)
from openqdc.utils.package_utils import has_package, requires_package
from openqdc.utils.regressor import Regressor # noqa
from openqdc.utils.units import get_conversion
from openqdc.utils.units import (
DistanceTypeConversion,
EnergyTypeConversion,
ForceTypeConversion,
get_conversion,
)

if has_package("torch"):
import torch
Expand Down Expand Up @@ -129,7 +134,7 @@ def __init__(
set_cache_dir(cache_dir)
# self._init_lambda_fn()
self.data = None
self._original_unit = self.__energy_unit__
self._original_unit = self.energy_unit
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
Expand Down Expand Up @@ -225,24 +230,27 @@ def e0s_dispatcher(self):
def _convert_data(self):
logger.info(
f"Converting {self.__name__} data to the following units:\n\
Energy: {self.energy_unit},\n\
Distance: {self.distance_unit},\n\
Forces: {self.force_unit if self.__force_methods__ else 'None'}"
Energy: {str(self.energy_unit)},\n\
Distance: {str(self.distance_unit)},\n\
Forces: {str(self.force_unit) if self.__force_methods__ else 'None'}"
)
for key in self.data_keys:
self.data[key] = self._convert_on_loading(self.data[key], key)

@property
def energy_unit(self):
return self.__energy_unit__
return EnergyTypeConversion(self.__energy_unit__)

@property
def distance_unit(self):
return self.__distance_unit__
return DistanceTypeConversion(self.__distance_unit__)

@property
def force_unit(self):
return self.__forces_unit__
units = self.__forces_unit__.split("/")
if len(units) > 2:
units = ["/".join(units[:2]), units[-1]]
return ForceTypeConversion(tuple(units)) # < 3.12 compatibility

@property
def root(self):
Expand Down Expand Up @@ -291,15 +299,15 @@ def data_shapes(self):
"forces": (-1, 3, len(self.force_methods)),
}

def _set_units(self, en, ds):
def _set_units(self, en: Optional[str] = None, ds: Optional[str] = None):
old_en, old_ds = self.energy_unit, self.distance_unit
en = en if en is not None else old_en
ds = ds if ds is not None else old_ds
self.set_energy_unit(en)
self.set_distance_unit(ds)
if self.__force_methods__:
self.__forces_unit__ = self.energy_unit + "/" + self.distance_unit
self._fn_forces = get_conversion(old_en + "/" + old_ds, self.__forces_unit__)
self._fn_forces = self.force_unit.to(str(self.energy_unit), str(self.distance_unit))
self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit)

def _set_isolated_atom_energies(self):
if self.__energy_methods__ is None:
Expand All @@ -308,7 +316,7 @@ def _set_isolated_atom_energies(self):
f = get_conversion("hartree", self.__energy_unit__)
else:
# regression are calculated on the original unit of the dataset
f = get_conversion(self._original_unit, self.__energy_unit__)
f = self._original_unit.to(self.energy_unit)
self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

def convert_energy(self, x):
Expand All @@ -324,17 +332,19 @@ def set_energy_unit(self, value: str):
"""
Set a new energy unit for the dataset.
"""
old_unit = self.energy_unit
# old_unit = self.energy_unit
# self.__energy_unit__ = value
self._fn_energy = self.energy_unit.to(value) # get_conversion(old_unit, value)
self.__energy_unit__ = value
self._fn_energy = get_conversion(old_unit, value)

def set_distance_unit(self, value: str):
"""
Set a new distance unit for the dataset.
"""
old_unit = self.distance_unit
# old_unit = self.distance_unit
# self.__distance_unit__ = value
self._fn_distance = self.distance_unit.to(value) # get_conversion(old_unit, value)
self.__distance_unit__ = value
self._fn_distance = get_conversion(old_unit, value)

def set_array_format(self, format: str):
assert format in ["numpy", "torch", "jax"], f"Format {format} not supported."
Expand Down
2 changes: 1 addition & 1 deletion openqdc/datasets/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(
self.recompute_statistics = True
self.refit_e0s = True
self.energy_type = energy_type
self._original_unit = energy_unit
self.__energy_unit__ = energy_unit
self._original_unit = self.energy_unit
self.__distance_unit__ = distance_unit
self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory]
self.energy_target_names = ["xyz"]
Expand Down
3 changes: 3 additions & 0 deletions openqdc/datasets/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .nabladft import NablaDFT
from .orbnet_denali import OrbnetDenali
from .pcqm import PCQM_B3LYP, PCQM_PM6
from .qm1b import QM1B, QM1B_SMALL
from .qm7x import QM7X, QM7X_V2
from .qmugs import QMugs, QMugs_V2
from .revmd17 import RevMD17
Expand Down Expand Up @@ -39,6 +40,8 @@
"QM7X_V2": QM7X_V2,
"QMugs": QMugs,
"QMugs_V2": QMugs_V2,
"QM1B": QM1B,
"QM1B_SMALL": QM1B_SMALL,
"SN2RXN": SN2RXN,
"SolvatedPeptides": SolvatedPeptides,
"Spice": Spice,
Expand Down
16 changes: 15 additions & 1 deletion openqdc/datasets/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,21 @@ class ANI1CCX_V2(ANI1CCX):


class ANI2X(ANI1):
""" """
"""
The ANI-2X dataset was constructed using active learning from modified versions of GDB-11, CheMBL,
and s66x8. It adds three new elements (F, Cl, S) resulting in 4.6 million conformers from 13k
chemical isomers, optimized using the LBFGS algorithm and labeled with ωB97X/6-31G*.
Usage
```python
from openqdc.datasets import ANI2X
dataset = ANI2X()
```
References:
- ANI-2x: https://doi.org/10.1021/acs.jctc.0c00121
- Github: https://github.com/aiqm/ANI1x_datasets
"""

__name__ = "ani2x"
__energy_unit__ = "hartree"
Expand Down
157 changes: 157 additions & 0 deletions openqdc/datasets/potential/qm1b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import os
from functools import partial
from os.path import join as p_join

import datamol as dm
import numpy as np
import pandas as pd

from openqdc.datasets.base import BaseDataset
from openqdc.methods import PotentialMethod
from openqdc.utils.io import get_local_cache

# fmt: off
# Figshare file ids of the QM1B parquet files, in ascending order.
# The first id (43005175) is the one used for "qm1b_validation.parquet" in
# QM1B.__links__.
# Every entry needs its own trailing comma: Python silently concatenates
# adjacent string literals, which would fuse two ids into one invalid id.
FILE_NUM = [
    "43005175","43005205","43005208","43005211","43005214","43005223",
    "43005235","43005241","43005244","43005247","43005253","43005259",
    "43005265","43005268","43005271","43005274","43005277","43005280",
    "43005286","43005292","43005298","43005304","43005307","43005313",
    "43005319","43005322","43005325","43005331","43005337","43005343",
    "43005346","43005349","43005352","43005355","43005358","43005364",
    "43005370","43005406","43005409","43005415","43005418","43005421",
    "43005424","43005427","43005430","43005433","43005436","43005439",
    "43005442","43005448","43005454","43005457","43005460","43005463",
    "43005466","43005469","43005472","43005475","43005478","43005481",
    "43005484","43005487","43005490","43005496","43005499","43005502",
    "43005505","43005508","43005511","43005514","43005517","43005520",
    "43005523","43005526","43005532","43005538","43005544","43005547",
    "43005550","43005553","43005556","43005559","43005562","43005577",
    "43005580","43005583","43005589","43005592","43005598","43005601",
    "43005616","43005622","43005625","43005628","43005634","43005637",
    "43005646","43005649","43005658","43005661","43005676","43006159",
    "43006162","43006165","43006168","43006171","43006174","43006177",
    "43006180","43006186","43006207","43006210","43006213","43006219",
    "43006222","43006228","43006231","43006273","43006276","43006279",
    "43006282","43006288","43006294","43006303","43006318","43006324",
    "43006330","43006333","43006336","43006345","43006354","43006372",
    "43006381","43006384","43006390","43006396","43006405","43006408",
    "43006411","43006423","43006432","43006456","43006468","43006471",
    "43006477","43006486","43006489","43006492","43006498","43006501",
    "43006513","43006516","43006522","43006525","43006528","43006531",
    "43006534","43006537","43006543","43006546","43006576","43006579",
    "43006603","43006609","43006615","43006621","43006624","43006627",
    "43006630","43006633","43006639","43006645","43006651","43006654",
    "43006660","43006663","43006666","43006669","43006672","43006681",
    "43006690","43006696","43006699","43006711","43006717","43006738",
    "43006747","43006756","43006762","43006765","43006768","43006771",
    "43006777","43006780","43006786","43006789","43006795","43006798",
    "43006801","43006804","43006816","43006822","43006837","43006840",
    "43006846","43006855","43006858","43006861","43006864","43006867",
    "43006870","43006873","43006876","43006882","43006897","43006900",
    "43006903","43006909","43006912","43006927","43006930","43006933",
    "43006939","43006942","43006948","43006951","43006954","43006957",
    "43006966","43006969","43006978","43006984","43006993","43006996",
    "43006999","43007002","43007005","43007008","43007011","43007014",
    "43007017","43007032","43007035","43007041","43007044","43007047",
    "43007050","43007053","43007056","43007062","43007068","43007080",
    "43007098","43007110","43007119","43007122","43007125",
]
# fmt: on


def extract_from_row(row, file_idx=None):
    """
    Convert one QM1B record into a dictionary of numpy arrays.

    Parameters
    ----------
    row
        Mapping-like record (e.g. a pandas Series or a dict) exposing the keys
        "smile", "z", "pos" and "energy".
    file_idx
        Optional identifier appended to the subset label. When None the
        subset is "qm1b", otherwise "qm1b_{file_idx}".

    Returns
    -------
    dict
        name: array holding the single SMILES string,
        subset: array holding the single subset label,
        energies: float64 array of shape (1, 1),
        atomic_inputs: float32 array of shape (n_atoms, 5) laid out as
            [z, 0, x, y, z-coordinate],
        n_atoms: int32 array of shape (1,).
    """
    smiles = row["smile"]

    atomic_numbers = np.array(row["z"])[:, None]
    # Second column is filled with zeros.
    # NOTE(review): presumably the per-atom charge slot - confirm against BaseDataset.
    zero_column = np.zeros_like(atomic_numbers)
    species = np.concatenate((atomic_numbers, zero_column), axis=1)

    # Flat coordinate list -> one (x, y, z) triple per atom.
    coordinates = np.array(row["pos"]).reshape(-1, 3)

    if file_idx is None:
        subset_label = "qm1b"
    else:
        subset_label = f"qm1b_{file_idx}"

    return {
        "name": np.array([smiles]),
        "subset": np.array([subset_label]),
        "energies": np.array([row["energy"]]).astype(np.float64)[:, None],
        "atomic_inputs": np.concatenate((species, coordinates), axis=-1, dtype=np.float32),
        "n_atoms": np.array([species.shape[0]], dtype=np.int32),
    }


class QM1B(BaseDataset):
    """
    QM1B is a low-resolution DFT dataset generated using PySCF IPU.
    It is composed of one billion training examples containing 9-11 heavy atoms.
    It was created by taking 1.09M SMILES strings from the GDB-11 database and
    computing molecular properties (e.g. HOMO-LUMO gap) for a set of up to 1000
    conformers per molecule at the B3LYP/STO-3G level of theory.

    The raw data consists of 256 training parquet files ("part_000.parquet" to
    "part_255.parquet") plus "qm1b_validation.parquet"; only energies are
    provided (no forces).

    Usage:
    ```python
    from openqdc.datasets import QM1B
    dataset = QM1B()
    ```

    References:
    - https://arxiv.org/pdf/2311.01135
    - https://github.com/graphcore-research/qm1b-dataset/
    """

    __name__ = "qm1b"

    __energy_methods__ = [PotentialMethod.B3LYP_STO3G]
    __force_methods__ = []

    energy_target_names = ["b3lyp/sto-3g"]
    force_target_names = []

    __energy_unit__ = "ev"
    __distance_unit__ = "bohr"
    __forces_unit__ = "ev/bohr"
    __links__ = {
        "qm1b_validation.parquet": "https://ndownloader.figshare.com/files/43005175",
        # The 256 part ids are taken from the tail of FILE_NUM rather than from
        # index 0, so that a leading validation id (43005175, linked explicitly
        # above) is never reused for "part_000.parquet".
        # NOTE(review): assumes FILE_NUM is ordered as validation id followed by
        # the part ids - confirm against the figshare record.
        **{
            f"part_{i:03d}.parquet": f"https://ndownloader.figshare.com/files/{file_id}"
            for i, file_id in enumerate(FILE_NUM[-256:])
        },
    }

    @property
    def root(self):
        """Directory of the raw QM1B files inside the local cache."""
        return p_join(get_local_cache(), "qm1b")

    @property
    def preprocess_path(self):
        """Directory of the preprocessed files for this dataset; created on access."""
        path = p_join(self.root, "preprocessed", self.__name__)
        os.makedirs(path, exist_ok=True)
        return path

    def read_raw_entries(self):
        """
        Read every raw parquet file and convert each row with `extract_from_row`.

        The 256 part files are read first, followed by the validation file.
        Files are dispatched to worker processes; rows inside one file are
        converted with threads.

        Returns
        -------
        list
            Flat list with one `extract_from_row` result per row, in file order.
        """
        filenames = [p_join(self.root, f"part_{i:03d}.parquet") for i in range(256)]
        filenames.append(p_join(self.root, "qm1b_validation.parquet"))

        def read_entries_parallel(filename):
            df = pd.read_parquet(filename)

            def extract_parallel(frame, i):
                return extract_from_row(frame.iloc[i])

            # Bind the dataframe once so the workers only receive a row index.
            fn = partial(extract_parallel, df)
            row_indices = list(range(len(df)))
            return dm.utils.parallelized(fn, row_indices, scheduler="threads", progress=False)

        list_of_list = dm.utils.parallelized(read_entries_parallel, filenames, scheduler="processes", progress=True)

        return [entry for entries in list_of_list for entry in entries]


class QM1B_SMALL(QM1B):
    """
    Reduced variant of QM1B holding at most 15 randomly chosen conformers
    per molecule.

    NOTE(review): only `__name__` is overridden here, so the raw file links,
    the raw-data root and `read_raw_entries` are inherited unchanged from QM1B;
    the preprocessed data is looked up under the "qm1b_small" name - confirm
    that the subset is produced outside of this class.

    Usage:
    ```python
    from openqdc.datasets import QM1B_SMALL
    dataset = QM1B_SMALL()
    ```
    """

    __name__ = "qm1b_small"
4 changes: 3 additions & 1 deletion openqdc/methods/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class StrEnum(str, Enum):
def __str__(self):
return self.value
return self.value.lower()


@unique
Expand Down Expand Up @@ -45,6 +45,7 @@ class BasisSet(StrEnum):
HA_DZ = "haDZ"
HA_TZ = "haTZ"
CBS_ADZ = "cbs(adz)"
STO3G = "sto-3g"
GSTAR = "6-31g*"
CC_PVDZ = "cc-pvdz"
CC_PVTZ = "cc-pvtz"
Expand Down Expand Up @@ -231,6 +232,7 @@ class PotentialMethod(QmMethod): # SPLIT FOR INTERACTIO ENERGIES AND FIX MD1
B1PW91_VWN5_DZP = Functional.B1PW91_VWN5, BasisSet.DZP
B1PW91_VWN5_SZ = Functional.B1PW91_VWN5, BasisSet.SZ
B1PW91_VWN5_TZP = Functional.B1PW91_VWN5, BasisSet.TZP
B3LYP_STO3G = Functional.B3LYP, BasisSet.STO3G # TODO: calculate e0s
B3LYP_VWN5_DZP = Functional.B3LYP_VWN5, BasisSet.DZP
B3LYP_VWN5_SZ = Functional.B3LYP_VWN5, BasisSet.SZ
B3LYP_VWN5_TZP = Functional.B3LYP_VWN5, BasisSet.TZP
Expand Down
Loading

0 comments on commit e1190e3

Please sign in to comment.