Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Units improvements #100

Merged
merged 10 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: ["ubuntu-latest"]

runs-on: ${{ matrix.os }}
Expand Down
5 changes: 5 additions & 0 deletions openqdc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def get_project_root():
"SN2RXN": "openqdc.datasets.potential.sn2_rxn",
"QM7X": "openqdc.datasets.potential.qm7x",
"QM7X_V2": "openqdc.datasets.potential.qm7x",
"QM1B": "openqdc.datasets.potential.qm1b",
"QM1B_SMALL": "openqdc.datasets.potential.qm1b",
"NablaDFT": "openqdc.datasets.potential.nabladft",
"SolvatedPeptides": "openqdc.datasets.potential.solvated_peptides",
"WaterClusters": "openqdc.datasets.potential.waterclusters3_30",
Expand Down Expand Up @@ -101,6 +103,8 @@ def __dir__():
from .datasets.interaction.metcalf import Metcalf
from .datasets.interaction.splinter import Splinter
from .datasets.interaction.x40 import X40

# POTENTIAL
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2X
from .datasets.potential.comp6 import COMP6
from .datasets.potential.dummy import Dummy
Expand All @@ -113,6 +117,7 @@ def __dir__():
from .datasets.potential.nabladft import NablaDFT
from .datasets.potential.orbnet_denali import OrbnetDenali
from .datasets.potential.pcqm import PCQM_B3LYP, PCQM_PM6
from .datasets.potential.qm1b import QM1B, QM1B_SMALL
from .datasets.potential.qm7x import QM7X, QM7X_V2
from .datasets.potential.qmugs import QMugs, QMugs_V2
from .datasets.potential.revmd17 import RevMD17
Expand Down
42 changes: 26 additions & 16 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
)
from openqdc.utils.package_utils import has_package, requires_package
from openqdc.utils.regressor import Regressor # noqa
from openqdc.utils.units import get_conversion
from openqdc.utils.units import (
DistanceTypeConversion,
EnergyTypeConversion,
ForceTypeConversion,
get_conversion,
)

if has_package("torch"):
import torch
Expand Down Expand Up @@ -129,7 +134,7 @@ def __init__(
set_cache_dir(cache_dir)
# self._init_lambda_fn()
self.data = None
self._original_unit = self.__energy_unit__
self._original_unit = self.energy_unit
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
Expand Down Expand Up @@ -225,24 +230,27 @@ def e0s_dispatcher(self):
def _convert_data(self):
logger.info(
f"Converting {self.__name__} data to the following units:\n\
Energy: {self.energy_unit},\n\
Distance: {self.distance_unit},\n\
Forces: {self.force_unit if self.__force_methods__ else 'None'}"
Energy: {str(self.energy_unit)},\n\
Distance: {str(self.distance_unit)},\n\
Forces: {str(self.force_unit) if self.__force_methods__ else 'None'}"
)
for key in self.data_keys:
self.data[key] = self._convert_on_loading(self.data[key], key)

@property
def energy_unit(self):
return self.__energy_unit__
return EnergyTypeConversion(self.__energy_unit__)

@property
def distance_unit(self):
return self.__distance_unit__
return DistanceTypeConversion(self.__distance_unit__)

@property
def force_unit(self):
return self.__forces_unit__
units = self.__forces_unit__.split("/")
if len(units) > 2:
units = ["/".join(units[:2]), units[-1]]
return ForceTypeConversion(tuple(units)) # < 3.12 compatibility

@property
def root(self):
Expand Down Expand Up @@ -291,15 +299,15 @@ def data_shapes(self):
"forces": (-1, 3, len(self.force_methods)),
}

def _set_units(self, en, ds):
def _set_units(self, en: Optional[str] = None, ds: Optional[str] = None):
old_en, old_ds = self.energy_unit, self.distance_unit
en = en if en is not None else old_en
ds = ds if ds is not None else old_ds
self.set_energy_unit(en)
self.set_distance_unit(ds)
if self.__force_methods__:
self.__forces_unit__ = self.energy_unit + "/" + self.distance_unit
self._fn_forces = get_conversion(old_en + "/" + old_ds, self.__forces_unit__)
self._fn_forces = self.force_unit.to(str(self.energy_unit), str(self.distance_unit))
self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit)

def _set_isolated_atom_energies(self):
if self.__energy_methods__ is None:
Expand All @@ -308,7 +316,7 @@ def _set_isolated_atom_energies(self):
f = get_conversion("hartree", self.__energy_unit__)
else:
# regression are calculated on the original unit of the dataset
f = get_conversion(self._original_unit, self.__energy_unit__)
f = self._original_unit.to(self.energy_unit)
self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

def convert_energy(self, x):
Expand All @@ -324,17 +332,19 @@ def set_energy_unit(self, value: str):
"""
Set a new energy unit for the dataset.
"""
old_unit = self.energy_unit
# old_unit = self.energy_unit
# self.__energy_unit__ = value
self._fn_energy = self.energy_unit.to(value) # get_conversion(old_unit, value)
self.__energy_unit__ = value
self._fn_energy = get_conversion(old_unit, value)

def set_distance_unit(self, value: str):
"""
Set a new distance unit for the dataset.
"""
old_unit = self.distance_unit
# old_unit = self.distance_unit
# self.__distance_unit__ = value
self._fn_distance = self.distance_unit.to(value) # get_conversion(old_unit, value)
self.__distance_unit__ = value
self._fn_distance = get_conversion(old_unit, value)

def set_array_format(self, format: str):
assert format in ["numpy", "torch", "jax"], f"Format {format} not supported."
Expand Down
2 changes: 1 addition & 1 deletion openqdc/datasets/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(
self.recompute_statistics = True
self.refit_e0s = True
self.energy_type = energy_type
self._original_unit = energy_unit
self.__energy_unit__ = energy_unit
self._original_unit = self.energy_unit
self.__distance_unit__ = distance_unit
self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory]
self.energy_target_names = ["xyz"]
Expand Down
3 changes: 3 additions & 0 deletions openqdc/datasets/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .nabladft import NablaDFT
from .orbnet_denali import OrbnetDenali
from .pcqm import PCQM_B3LYP, PCQM_PM6
from .qm1b import QM1B, QM1B_SMALL
from .qm7x import QM7X, QM7X_V2
from .qmugs import QMugs, QMugs_V2
from .revmd17 import RevMD17
Expand Down Expand Up @@ -39,6 +40,8 @@
"QM7X_V2": QM7X_V2,
"QMugs": QMugs,
"QMugs_V2": QMugs_V2,
"QM1B": QM1B,
"QM1B_SMALL": QM1B_SMALL,
"SN2RXN": SN2RXN,
"SolvatedPeptides": SolvatedPeptides,
"Spice": Spice,
Expand Down
16 changes: 15 additions & 1 deletion openqdc/datasets/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,21 @@ class ANI1CCX_V2(ANI1CCX):


class ANI2X(ANI1):
""" """
"""
The ANI-2X dataset was constructed using active learning from modified versions of GDB-11, CheMBL,
and s66x8. It adds three new elements (F, Cl, S) resulting in 4.6 million conformers from 13k
chemical isomers, optimized using the LBFGS algorithm and labeled with ωB97X/6-31G*.

Usage
```python
from openqdc.datasets import ANI2X
dataset = ANI2X()
```

References:
- ANI-2x: https://doi.org/10.1021/acs.jctc.0c00121
- Github: https://github.com/aiqm/ANI1x_datasets
"""

__name__ = "ani2x"
__energy_unit__ = "hartree"
Expand Down
157 changes: 157 additions & 0 deletions openqdc/datasets/potential/qm1b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import os
from functools import partial
from os.path import join as p_join

import datamol as dm
import numpy as np
import pandas as pd

from openqdc.datasets.base import BaseDataset
from openqdc.methods import PotentialMethod
from openqdc.utils.io import get_local_cache

# fmt: off
# Figshare file ids used to build the QM1B download links.
# Every entry must be followed by a comma: adjacent string literals are
# silently concatenated by Python, which would yield one invalid 16-digit id.
# NOTE(review): the first id ("43005175") is also the id hard-coded for
# "qm1b_validation.parquet" in QM1B.__links__, which indexes this list as
# FILE_NUM[i] for i in range(256) -- confirm the intended index offset.
FILE_NUM = [
    "43005175","43005205","43005208","43005211","43005214","43005223",
    "43005235","43005241","43005244","43005247","43005253","43005259",
    "43005265","43005268","43005271","43005274","43005277","43005280",
    "43005286","43005292","43005298","43005304","43005307","43005313",
    "43005319","43005322","43005325","43005331","43005337","43005343",
    "43005346","43005349","43005352","43005355","43005358","43005364",
    "43005370","43005406","43005409","43005415","43005418","43005421",
    "43005424","43005427","43005430","43005433","43005436","43005439",
    "43005442","43005448","43005454","43005457","43005460","43005463",
    "43005466","43005469","43005472","43005475","43005478","43005481",
    "43005484","43005487","43005490","43005496","43005499","43005502",
    "43005505","43005508","43005511","43005514","43005517","43005520",
    "43005523","43005526","43005532","43005538","43005544","43005547",
    "43005550","43005553","43005556","43005559","43005562","43005577",
    "43005580","43005583","43005589","43005592","43005598","43005601",
    "43005616","43005622","43005625","43005628","43005634","43005637",
    "43005646","43005649","43005658","43005661","43005676","43006159",
    "43006162","43006165","43006168","43006171","43006174","43006177",
    "43006180","43006186","43006207","43006210","43006213","43006219",
    "43006222","43006228","43006231","43006273","43006276","43006279",
    "43006282","43006288","43006294","43006303","43006318","43006324",
    "43006330","43006333","43006336","43006345","43006354","43006372",
    "43006381","43006384","43006390","43006396","43006405","43006408",
    "43006411","43006423","43006432","43006456","43006468","43006471",
    "43006477","43006486","43006489","43006492","43006498","43006501",
    "43006513","43006516","43006522","43006525","43006528","43006531",
    "43006534","43006537","43006543","43006546","43006576","43006579",
    "43006603","43006609","43006615","43006621","43006624","43006627",
    "43006630","43006633","43006639","43006645","43006651","43006654",
    "43006660","43006663","43006666","43006669","43006672","43006681",
    "43006690","43006696","43006699","43006711","43006717","43006738",
    "43006747","43006756","43006762","43006765","43006768","43006771",
    "43006777","43006780","43006786","43006789","43006795","43006798",
    "43006801","43006804","43006816","43006822","43006837","43006840",
    "43006846","43006855","43006858","43006861","43006864","43006867",
    "43006870","43006873","43006876","43006882","43006897","43006900",
    "43006903","43006909","43006912","43006927","43006930","43006933",
    "43006939","43006942","43006948","43006951","43006954","43006957",
    "43006966","43006969","43006978","43006984","43006993","43006996",
    "43006999","43007002","43007005","43007008","43007011","43007014",
    "43007017","43007032","43007035","43007041","43007044","43007047",
    "43007050","43007053","43007056","43007062","43007068","43007080",
    "43007098","43007110","43007119","43007122","43007125",
]
# fmt: on


def extract_from_row(row, file_idx=None):
    """
    Turn one record of a QM1B parquet table into a dictionary of arrays.

    Parameters
    ----------
    row:
        Mapping-like record exposing the keys "smile", "z", "pos" and "energy".
    file_idx:
        Optional identifier appended to the subset label ("qm1b_<file_idx>");
        when None the subset label is plain "qm1b".

    Returns
    -------
    dict
        Keys "name", "subset", "energies", "atomic_inputs" and "n_atoms",
        each holding a numpy array.
    """
    smiles = row["smile"]
    atomic_numbers = np.array(row["z"])[:, None]
    # Zero-filled second column (presumably the charge slot -- confirm against
    # the layout the base dataset expects for atomic_inputs).
    zero_column = np.zeros_like(atomic_numbers)
    species = np.concatenate((atomic_numbers, zero_column), axis=1)
    # "pos" is flattened; fold it back into one xyz triple per atom.
    coordinates = np.array(row["pos"]).reshape(-1, 3)

    subset_label = "qm1b" if file_idx is None else f"qm1b_{file_idx}"

    return {
        "name": np.array([smiles]),
        "subset": np.array([subset_label]),
        "energies": np.array([row["energy"]]).astype(np.float64)[:, None],
        "atomic_inputs": np.concatenate((species, coordinates), axis=-1, dtype=np.float32),
        "n_atoms": np.array([species.shape[0]], dtype=np.int32),
    }


class QM1B(BaseDataset):
    """
    QM1B is a low-resolution DFT dataset generated using PySCF IPU.
    It contains one billion training examples of molecules with 9-11 heavy
    atoms. It was built from 1.09M SMILES strings of the GDB-11 database by
    computing molecular properties (e.g. HOMO-LUMO gap) for up to 1000
    conformers per molecule at the B3LYP/STO-3G level of theory.

    Usage:
    ```python
    from openqdc.datasets import QM1B
    dataset = QM1B()
    ```

    References:
    - https://arxiv.org/pdf/2311.01135
    - https://github.com/graphcore-research/qm1b-dataset/
    """

    __name__ = "qm1b"

    __energy_methods__ = [PotentialMethod.B3LYP_STO3G]
    __force_methods__ = []

    energy_target_names = ["b3lyp/sto-3g"]
    force_target_names = []

    __energy_unit__ = "ev"
    __distance_unit__ = "bohr"
    __forces_unit__ = "ev/bohr"
    # One link for the validation table plus one per training part (000-255).
    # NOTE(review): part_000 resolves to FILE_NUM[0], which is the same id as
    # the validation file above -- confirm the intended index offset.
    __links__ = {
        "qm1b_validation.parquet": "https://ndownloader.figshare.com/files/43005175",
        **{f"part_{i:03d}.parquet": f"https://ndownloader.figshare.com/files/{FILE_NUM[i]}" for i in range(256)},
    }

    @property
    def root(self):
        """Folder named "qm1b" inside the local cache directory."""
        return p_join(get_local_cache(), "qm1b")

    @property
    def preprocess_path(self):
        """Preprocessed-data folder for this dataset name; created if missing."""
        preprocessed_dir = p_join(self.root, "preprocessed", self.__name__)
        os.makedirs(preprocessed_dir, exist_ok=True)
        return preprocessed_dir

    def read_raw_entries(self):
        """
        Read every parquet file of the dataset and convert each row to an entry.

        The 256 training parts are read first, followed by the validation
        table. Files are dispatched with the "processes" scheduler; rows
        within a file are converted with the "threads" scheduler.

        Returns a flat list with one ``extract_from_row`` result per row.
        """
        part_files = [p_join(self.root, f"part_{part:03d}.parquet") for part in range(256)]
        parquet_files = part_files + [p_join(self.root, "qm1b_validation.parquet")]

        def load_parquet(path):
            frame = pd.read_parquet(path)

            def convert_row(row_idx):
                return extract_from_row(frame.iloc[row_idx])

            row_indices = list(range(len(frame)))
            return dm.utils.parallelized(convert_row, row_indices, scheduler="threads", progress=False)

        per_file_entries = dm.utils.parallelized(load_parquet, parquet_files, scheduler="processes", progress=True)

        # Flatten the per-file lists into a single list of entries.
        entries = []
        for file_entries in per_file_entries:
            entries.extend(file_entries)
        return entries


class QM1B_SMALL(QM1B):
    """
    Reduced variant of QM1B that keeps at most 15 random conformers
    per molecule.

    Only the dataset name is overridden here; everything else is
    inherited from QM1B.

    Usage:
    ```python
    from openqdc.datasets import QM1B_SMALL
    dataset = QM1B_SMALL()
    ```
    """

    __name__ = "qm1b_small"
4 changes: 3 additions & 1 deletion openqdc/methods/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class StrEnum(str, Enum):
def __str__(self):
return self.value
return self.value.lower()


@unique
Expand Down Expand Up @@ -45,6 +45,7 @@ class BasisSet(StrEnum):
HA_DZ = "haDZ"
HA_TZ = "haTZ"
CBS_ADZ = "cbs(adz)"
STO3G = "sto-3g"
GSTAR = "6-31g*"
CC_PVDZ = "cc-pvdz"
CC_PVTZ = "cc-pvtz"
Expand Down Expand Up @@ -231,6 +232,7 @@ class PotentialMethod(QmMethod): # SPLIT FOR INTERACTIO ENERGIES AND FIX MD1
B1PW91_VWN5_DZP = Functional.B1PW91_VWN5, BasisSet.DZP
B1PW91_VWN5_SZ = Functional.B1PW91_VWN5, BasisSet.SZ
B1PW91_VWN5_TZP = Functional.B1PW91_VWN5, BasisSet.TZP
B3LYP_STO3G = Functional.B3LYP, BasisSet.STO3G # TODO: calculate e0s
B3LYP_VWN5_DZP = Functional.B3LYP_VWN5, BasisSet.DZP
B3LYP_VWN5_SZ = Functional.B3LYP_VWN5, BasisSet.SZ
B3LYP_VWN5_TZP = Functional.B3LYP_VWN5, BasisSet.TZP
Expand Down
Loading
Loading