diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..5b1f7c7 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @prtos diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..18c9147 --- /dev/null +++ b/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. 
+ +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..ee1bbed --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,7 @@ +Checklist: + +- [ ] Was this PR discussed in an issue? It is recommended to first discuss a new feature in a GitHub issue before opening a PR. +- [ ] Add tests to cover the fixed bug(s) or the newly introduced feature(s) (if appropriate). +- [ ] Update the API documentation if a new function is added or an existing one is deleted. + +--- diff --git a/.github/SECURITY.md b/.github/SECURITY.md new file mode 100644 index 0000000..c9d0753 --- /dev/null +++ b/.github/SECURITY.md @@ -0,0 +1,3 @@ +# Security Policy + +Please report any security-related issues directly to prudencio@valencediscovery.com.
diff --git a/.github/workflows/code-check.yml b/.github/workflows/code-check.yml new file mode 100644 index 0000000..7c4be3b --- /dev/null +++ b/.github/workflows/code-check.yml @@ -0,0 +1,55 @@ +name: code-check + +on: + push: + branches: ["main"] + tags: ["*"] + pull_request: + branches: + - "*" + - "!gh-pages" + +jobs: + python-format-black: + name: Python lint [black] + runs-on: ubuntu-latest + steps: + - name: Checkout the code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install black + run: | + pip install black>=23 + + - name: Lint + run: black --check . + + python-typing-mypy: + name: Python typing check [mypy] + runs-on: ubuntu-latest + steps: + - name: Checkout the code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install mypy + run: | + pip install mypy numpy pandas loguru pytest pillow scipy + + - name: Run code check + run: | + mypy . || exitCode=$? + + # only fails if exit code >=2 + if [ $exitCode -ge 2 ]; then + exit $exitCode + fi diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml new file mode 100644 index 0000000..19b96b5 --- /dev/null +++ b/.github/workflows/doc.yml @@ -0,0 +1,48 @@ +name: doc + +on: + push: + branches: ["main"] + +# Prevent doc action on `main` to conflict with each others. +concurrency: + group: doc-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: write + +jobs: + doc: + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + + steps: + - name: Checkout the code + uses: actions/checkout@v3 + + - name: Setup mamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: env.yml + environment-name: openqdc + cache-environment: true + cache-downloads: true + + - name: Install library + run: python -m pip install --no-deps . + + - name: Configure git + run: | + git config --global user.name "${GITHUB_ACTOR}" + git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com" + + - name: Deploy the doc + run: | + echo "Get the gh-pages branch" + git fetch origin gh-pages + + echo "Build and deploy the doc on main" + mike deploy --push --force main diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..966cbdb --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,158 @@ +name: release + +on: + workflow_dispatch: + inputs: + release-version: + description: "A valid Semver version string" + required: true + +permissions: + contents: write + pull-requests: write + +jobs: + release: + # Do not release if not triggered from the default branch + if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch) + + runs-on: ubuntu-latest + timeout-minutes: 30 + + defaults: + run: + shell: bash -l {0} + + steps: + - name: Checkout the code + uses: actions/checkout@v3 + + - name: Setup mamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: env.yml + environment-name: openqdc + cache-environment: true + cache-downloads: true + + - name: Check the version is valid semver + run: | + RELEASE_VERSION="${{ inputs.release-version }}" + + { + pysemver check $RELEASE_VERSION + } || { + echo "The version '$RELEASE_VERSION' is not a valid Semver version string." + echo "Please use a valid semver version string. More details at https://semver.org/" + echo "The release process is aborted." 
+ exit 1 + } + + - name: Check the version is higher than the latest one + run: | + # Retrieve the git tags first + git fetch --prune --unshallow --tags &> /dev/null + + RELEASE_VERSION="${{ inputs.release-version }}" + LATEST_VERSION=$(git describe --abbrev=0 --tags) + + IS_HIGHER_VERSION=$(pysemver compare $RELEASE_VERSION $LATEST_VERSION) + + if [ "$IS_HIGHER_VERSION" != "1" ]; then + echo "The version '$RELEASE_VERSION' is not higher than the latest version '$LATEST_VERSION'." + echo "The release process is aborted." + exit 1 + fi + + - name: Build Changelog + id: github_release + uses: mikepenz/release-changelog-builder-action@v4 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + toTag: "main" + + - name: Configure git + run: | + git config --global user.name "${GITHUB_ACTOR}" + git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com" + + - name: Create and push git tag + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + # Tag the release + git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}" + + # Push the modified changelogs + git push origin main + + # Push the tags + git push origin "${{ inputs.release-version }}" + + - name: Install library + run: python -m pip install --no-deps . + + - name: Build the wheel and sdist + run: python -m build --no-isolation + + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} + packages-dir: dist/ + + - name: Create GitHub Release + uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 + with: + tag_name: ${{ inputs.release-version }} + body: ${{steps.github_release.outputs.changelog}} + + - name: Deploy the doc + run: | + echo "Get the gh-pages branch" + git fetch origin gh-pages + + echo "Build and deploy the doc on ${{ inputs.release-version }}" + mike deploy --push --force stable + mike deploy --push --force ${{ inputs.release-version }} + + build-installer-linux: + needs: [release] + + runs-on: ubuntu-latest + timeout-minutes: 30 + + defaults: + run: + shell: bash -l {0} + + steps: + - name: Checkout the code + uses: actions/checkout@v3 + with: + ref: ${{ inputs.release-version }} + + - name: Setup mamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: env.yml + environment-name: openqdc + cache-environment: true + cache-downloads: true + + - name: Build the wheel and sdist + run: python -m build --no-isolation + + - name: Build standalone installer + run: | + export OPENQDC_CLIENT_CONSTRUCTOR_VERSION="${{ inputs.release-version }}" + bash ./scripts/build_installer.sh + + - name: Create GitHub Release + uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 + with: + tag_name: ${{ inputs.release-version }} + files: | + ./build/*.sh + ./build/*.sha256 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..997c796 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,54 @@ +name: test + +on: + push: + branches: ["main"] + tags: ["*"] + pull_request: + branches: + - "*" + - "!gh-pages" + schedule: + - cron: "0 4 * * MON" + +jobs: + test: + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10"] + os: ["ubuntu-latest"] + + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + + defaults: + run: + shell: bash -l {0} + + name: | + os=${{ matrix.os }} + - python=${{ matrix.python-version }} + + steps: + - name: Checkout the code + uses: actions/checkout@v3 + + - name: 
Setup mamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: env.yml + environment-name: openqdc + cache-environment: true + cache-downloads: true + create-args: >- + python=${{ matrix.python-version }} + + - name: Install library + run: python -m pip install --no-deps . + + - name: Run tests + run: pytest + + - name: Test building the doc + run: mkdocs build diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bc693a8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,149 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# PyCharm files +.idea/ +# Rever +rever/ +cache/ + +# Specifically for odd +*.json +*.json.bz2 +*.hdf5 +nohup.out +*.out +*.crt +*.key +*.dat +*.xyz +*.csv +*.txt + diff --git a/.project-root b/.project-root new file mode 100644 index 0000000..2226fe7 --- /dev/null +++ b/.project-root @@ -0,0 +1 @@ +# this file is required for inferring the project root directory diff --git a/README.md b/README.md new file mode 100644 index 0000000..9300362 --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +# openQDC + +Open Quantum Data Commons + +## Setup Datasets + +Use the scripts in `setup/` to download the datasets. For more information, see the [README](setup/README.md) in the `setup/` directory. + +```bash +# Install the library in dev mode +pip install -e . +``` + +## Development lifecycle + +### Tests + +You can run tests locally with: + +```bash +pytest +``` diff --git a/env.yml b/env.yml new file mode 100644 index 0000000..8eb195c --- /dev/null +++ b/env.yml @@ -0,0 +1,69 @@ +channels: + - conda-forge + - pyg # Only for macOS. Remove once https://github.com/conda-forge/pyg-lib-feedstock/pull/14 is merged.
+ +dependencies: + # standard stuff + - python >=3.8 + - pip + - tqdm + - loguru + - fsspec + - s3fs + - gcsfs + - joblib + - prettytable + - pyrootutils + + # Scientific + - pandas + - numpy + - scipy + - sympy + + # Chem + - ipdb + - datamol #==0.9.0 + - rdkit #-pypi #==2022.9.3 + - ase + + # ML + - e3nn =0.5.1 + - einops =0.6.0 + - pytorch =2.0.0 + - lightning =2.0.4 + - torchmetrics =0.11.4 + - tensorboard =2.11.2 + - umap-learn =0.5.3 + - pytorch_geometric >=2.3.1 + - pytorch_sparse >=0.6.17 + - pytorch_cluster >=1.6 + - pytorch_scatter >=2.1 + - torch-ema + + # other stuffs + - h5py >=3.8.0 + - omegaconf #==2.3.0 + - gdown #==4.6.4 + - hydra-core #==1.3.1 + - wandb #==0.13.10 + + # Viz + - matplotlib + - seaborn + - ipywidgets + - nglview + + # Dev + - pytest >=6.0 + - pytest-cov + - nbconvert + - black >=23 + - jupyterlab + - pre-commit + - ruff + - ipykernel + - pydantic <= 2.0 + + - pip: + - torch-nl diff --git a/openqdc/datasets/ani.py b/openqdc/datasets/ani.py new file mode 100644 index 0000000..6fa6c85 --- /dev/null +++ b/openqdc/datasets/ani.py @@ -0,0 +1,97 @@ +import numpy as np +from tqdm import tqdm +import datamol as dm +from os.path import join as p_join +from openqdc.utils import load_hdf5_file +from openqdc.utils.molecule import get_atom_data +from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER +from openqdc.datasets.base import BaseDataset + + +def read_record(r, r_name): + n_confs = r["coordinates"].shape[0] + x = r["atomic_numbers"][()] + xs = np.stack((x, np.zeros_like(x)), axis=-1) + positions= r["coordinates"][()] * BOHR2ANG + energies= np.stack([r[k] for k in Ani1.energy_target_names], axis=-1) + forces= np.stack([r[k] for k in Ani1.force_target_names], axis=-1) + + res = dict( + smiles= np.array([r_name]*n_confs), + subset= np.array([r_name]*n_confs), + energies= energies.astype(np.float32), + forces= forces.reshape(-1, *forces.shape[-2:]).astype(np.float32), + atom_data_and_positions = np.concatenate(( + xs[None, ...].repeat(n_confs, axis=0), + positions), axis=-1, dtype=np.float32).reshape(-1, 5), + n_atoms = np.array([x.shape[0]]*n_confs, dtype=np.int32), + ) + + return res + + +class Ani1(BaseDataset): + __name__ = 'ani' + + # Energy in hartree, all zeros by default + atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32) + + __qm_methods__ = [ + "ccsd(t)_cbs", + "hf_dz", + "hf_qz", + "hf_tz", + "mp2_dz", + "mp2_qz", + "mp2_tz", + "npno_ccsd(t)_dz", + "npno_ccsd(t)_tz", + "tpno_ccsd(t)_dz", + "wb97x_dz", + "wb97x_tz", + ] + + energy_target_names = [ + "ccsd(t)_cbs.energy", + "hf_dz.energy", + "hf_qz.energy", + "hf_tz.energy", + "mp2_dz.corr_energy", + "mp2_qz.corr_energy", + "mp2_tz.corr_energy", + "npno_ccsd(t)_dz.corr_energy", + "npno_ccsd(t)_tz.corr_energy", + "tpno_ccsd(t)_dz.corr_energy", + "wb97x_dz.energy", + "wb97x_tz.energy", + ] + + force_target_names = [ + "wb97x_dz.forces", + "wb97x_tz.forces" + ] + + def __init__(self) -> None: + super().__init__() + + def read_raw_entries(self): + raw_path = p_join(self.root, 'ani1.h5') + data = load_hdf5_file(raw_path) + + fn = lambda x: read_record(x[0], x[1]) + tmp = [(data[mol_name], mol_name) for mol_name in data.keys()] + samples = dm.parallelized(fn, tmp, n_jobs=1, progress=True) # don't use more than 1 job + return samples + + +if __name__ == '__main__': + data = Ani1() + n = len(data) + + for i in np.random.choice(n, 100, replace=False): + x = data[i] + for k in x: + print(x.smiles, x.subset, end=' ') + print(k, x[k].shape, end=' ') + + print() \ No newline at end of file diff 
--git a/openqdc/datasets/base.py b/openqdc/datasets/base.py new file mode 100644 index 0000000..a7180e6 --- /dev/null +++ b/openqdc/datasets/base.py @@ -0,0 +1,141 @@ +import os +import torch +import numpy as np +import pickle as pkl +from os.path import join as p_join +from sklearn.utils import Bunch +from openqdc.utils.paths import get_local_cache +from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER + + +class BaseDataset(torch.utils.data.Dataset): + __qm_methods__ = [] + + energy_target_names = [] + + force_target_names = [] + + energy_unit = "hartree" + + def __init__(self) -> None: + self.data = None + if not self.is_preprocessed(): + entries = self.read_raw_entries() + res = self.collate_list(entries) + self.save_preprocess(res) + self.read_preprocess() + + @property + def root(self): + return p_join(get_local_cache(), self.__name__) + + @property + def preprocess_path(self): + path = p_join(self.root, 'preprocessed') + os.makedirs(path, exist_ok=True) + return path + + @property + def data_types(self): + return { + "atom_data_and_positions": np.float32, + "position_idx_range": np.int32, + "energies": np.float32, + "forces": np.float32 + } + + @property + def data_shapes(self): + return { + "atom_data_and_positions": (-1, 5), + "position_idx_range": (-1, 2), + "energies": (-1, len(self.energy_target_names)), + "forces": (-1, 3, len(self.force_target_names)) + } + + def read_raw_entries(self): + raise NotImplementedError + + def collate_list(self, list_entries): + # concatenate entries + res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) + for key in list_entries[0]} + + csum = np.cumsum(res.pop("n_atoms")) + x = np.zeros((csum.shape[0], 2), dtype=np.int32) + x[1:, 0], x[:, 1] = csum[:-1], csum + res["position_idx_range"] = x + return res + + def save_preprocess(self, data_dict): + # save memmaps + for key in self.data_types: + if key not in data_dict: + continue + out = np.memmap(p_join(self.preprocess_path, f"{key}.mmap"), + mode="w+", + dtype=data_dict[key].dtype, + shape=data_dict[key].shape) + out[:] = data_dict.pop(key)[:] + out.flush() + + # save smiles and subset + for key in ["smiles", "subset"]: + uniques, inv_indices = np.unique(data_dict[key], return_inverse=True) + with open(p_join(self.preprocess_path, f"{key}.npz"), "wb") as f: + np.savez_compressed(f, uniques=uniques, inv_indices=inv_indices) + + def read_preprocess(self): + self.data = {} + for key in self.data_types: + filename = p_join(self.preprocess_path, f"{key}.mmap") + if not os.path.exists(filename): + continue + self.data[key] = np.memmap( + filename, mode='r', + dtype=self.data_types[key], + ).reshape(self.data_shapes[key]) + + for key in self.data: + print(f'Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}') + + for key in ["smiles", "subset"]: + filename = p_join(self.preprocess_path, f"{key}.npz") + # with open(filename, "rb") as f: + self.data[key] = np.load(open(filename, "rb")) + for k in self.data[key]: + print(f'Loaded {key}_{k} with shape {self.data[key][k].shape}, dtype {self.data[key][k].dtype}') + + def is_preprocessed(self): + filenames = [p_join(self.preprocess_path, f"{key}.mmap") + for key in self.data_types] + filenames += [p_join(self.preprocess_path, f"{x}.npz") + for x in ["smiles", "subset"]] + return all([os.path.exists(f) for f in filenames]) + + def __len__(self): + return self.data['energies'].shape[0] + + def __getitem__(self, idx: int): + p_start, p_end = self.data["position_idx_range"][idx] + input = 
self.data["atom_data_and_positions"][p_start:p_end] + z, positions = input[:, 0].astype(np.int32), input[:, 2:]  # columns: atomic number, formal charge, x, y, z + energies = self.data["energies"][idx] + e0 = self.atomic_energies[z] + smiles = self.data["smiles"]["uniques"][self.data["smiles"]["inv_indices"][idx]] + subset = self.data["subset"]["uniques"][self.data["subset"]["inv_indices"][idx]] + + if "forces" in self.data: + forces = self.data["forces"][p_start:p_end] + else: + forces = None + + return Bunch( + positions=positions, + atomic_numbers=z, + e0=e0, + energies=energies, + smiles=smiles, + subset=subset, + forces=forces + ) diff --git a/openqdc/datasets/geom.py b/openqdc/datasets/geom.py new file mode 100644 index 0000000..b387a3d --- /dev/null +++ b/openqdc/datasets/geom.py @@ -0,0 +1,103 @@ +import os +import torch +import pickle as pkl +import numpy as np +from tqdm import tqdm +import datamol as dm +from sklearn.utils import Bunch +from os.path import join as p_join +from openqdc.utils import load_pkl, load_json +from openqdc.utils.molecule import get_atom_data +from openqdc.utils.paths import get_local_cache +from openqdc.utils.constants import MAX_ATOMIC_NUMBER +from openqdc.datasets.base import BaseDataset + + +def read_mol(mol_id, mol_dict, base_path, partition): + """ Read molecule from pickle file and return dict with conformers and energies + + Parameters + ---------- + mol_id: str + Unique identifier for the molecule + mol_dict: dict + Dictionary containing the pickle_path and smiles of the molecule + base_path: str + Path to the folder containing the pickle files + + Returns + ------- + res: dict + Dictionary containing the following keys: + - atom_data_and_positions: flattened np.ndarray of shape (M, 5) containing the atomic numbers, formal charges and positions + - smiles: np.ndarray of shape (N,) containing the smiles of the molecule + - energies: np.ndarray of shape (N,1) containing the energies of the conformers + - n_atoms: np.ndarray of shape (N,) containing the number of atoms in each conformer + """ + + try: + d = load_pkl(p_join(base_path, mol_dict['pickle_path']), False) + confs = d['conformers'] + x = get_atom_data(confs[0]['rd_mol']) + positions = np.array([cf['rd_mol'].GetConformer().GetPositions() for cf in confs]) + n_confs = positions.shape[0] + + res = dict( + atom_data_and_positions = np.concatenate(( + x[None, ...].repeat(n_confs, axis=0), + positions), axis=-1, dtype=np.float32).reshape(-1, 5), + smiles = np.array([d['smiles'] for _ in confs]), + energies = np.array([cf['totalenergy'] for cf in confs], dtype=np.float32)[:, None], + n_atoms = np.array([positions.shape[1]] * n_confs, dtype=np.int32), + subset = np.array([partition] * n_confs), + ) + + except Exception as e: + print (f'Skipping: {mol_id} due to {e}') + res = None + + return res + + +class Geom(BaseDataset): + __name__ = 'geom' + __qm_methods__ = ["gfn2_xtb"] + + energy_target_names = ["gfn2_xtb.energy"] + force_target_names = [] + + # Energy in hartree, all zeros by default + atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32) + + partitions = ['qm9', 'drugs'] + + def __init__(self) -> None: + super().__init__() + + def _read_raw_(self, partition): + raw_path = p_join(self.root, 'rdkit_folder') + + mols = load_json(p_join(raw_path, f'summary_{partition}.json')) + mols = list(mols.items()) + + fn = lambda x: read_mol(x[0], x[1], raw_path, partition) + samples = dm.parallelized(fn, mols, n_jobs=1, progress=True) # don't use more than 1 job + return samples + + def read_raw_entries(self): + samples = sum([self._read_raw_(partition) for
partition in self.partitions], []) + return samples + + +if __name__ == '__main__': + data = Geom() + n = len(data) + + for i in np.random.choice(n, 10, replace=False): + x = data[i] + print(x.smiles, x.subset, end=' ') + for k in x: + if k != 'smiles' and k != 'subset': + print(k, x[k].shape if x[k] is not None else None, end=' ') + + print() \ No newline at end of file diff --git a/openqdc/datasets/molecule3d.py b/openqdc/datasets/molecule3d.py new file mode 100644 index 0000000..e1c7cf7 --- /dev/null +++ b/openqdc/datasets/molecule3d.py @@ -0,0 +1,93 @@ +import os +import torch +import pickle as pkl +import numpy as np +import pandas as pd +import os.path as osp +import datamol as dm +from tqdm import tqdm +from glob import glob +from sklearn.utils import Bunch +from rdkit import Chem +from os.path import join as p_join +from openqdc.utils import load_hdf5_file +from openqdc.utils.molecule import get_atom_data +from openqdc.utils.paths import get_local_cache +from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER +from openqdc.datasets.base import BaseDataset + + +def get_props(df, sdf, idx): + id = sdf.GetItemText(idx).split(" ")[1] + return df.loc[[id]].to_dict(orient="records")[0] + + +def read_mol(mol, props): + smiles = dm.to_smiles(mol, explicit_hs=False) + subset = dm.to_smiles(dm.to_scaffold_murcko(mol, make_generic=True), explicit_hs=False) + x = get_atom_data(mol) + positions= mol.GetConformer().GetPositions() * BOHR2ANG + + res = dict( + smiles= np.array([smiles]), + subset= np.array([subset]), + energies= np.array([props["scf energy"]]).astype(np.float32)[:, None], + atom_data_and_positions = np.concatenate((x, positions), axis=-1, dtype=np.float32), + n_atoms = np.array([x.shape[0]], dtype=np.int32), + ) + + # for key in res: + # print(key, res[key].shape, res[key].dtype) + # exit() + + return res + + +class Molecule3D(BaseDataset): + __name__ = 'molecule3d' + __qm_methods__ = ["b3lyp/6-31g*"] + + energy_target_names = ["b3lyp/6-31g*.energy"] + force_target_names = [] + + # Energy in hartree, all zeros by default + atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32) + + def __init__(self) -> None: + super().__init__() + + def read_raw_entries(self): + raw = p_join(self.root, 'data', 'raw') + sdf_paths = glob(p_join(raw, '*.sdf')) + properties_path = p_join(raw, 'properties.csv') + + properties = pd.read_csv(properties_path, dtype={"cid": str}) + properties.drop_duplicates(subset="cid", inplace=True, keep="first") + properties.set_index("cid", inplace=True) + n = len(sdf_paths) + + tmp = [] + for i, path in enumerate(sdf_paths): + suppl = Chem.SDMolSupplier(path, removeHs=False, sanitize=True) + n = len(suppl) + + tmp += [ + read_mol(suppl[j], get_props(properties, suppl, j)) + for j in tqdm(range(n), desc=f"{i+1}/{n}") + ] + + return tmp + + +if __name__ == '__main__': + data = Molecule3D() + n = len(data) + + for i in np.random.choice(n, 10, replace=False): + x = data[i] + print(x.smiles, x.subset, end=' ') + for k in x: + if k != 'smiles' and k != 'subset': + print(k, x[k].shape if x[k] is not None else None, end=' ') + + print() \ No newline at end of file diff --git a/openqdc/datasets/nabladft.py b/openqdc/datasets/nabladft.py new file mode 100644 index 0000000..f6405c4 --- /dev/null +++ b/openqdc/datasets/nabladft.py @@ -0,0 +1,111 @@ +import os +import torch +import pickle as pkl +import numpy as np +from tqdm import tqdm +import datamol as dm +from sklearn.utils import Bunch +from os.path import join as p_join +from openqdc.utils import 
load_pkl, load_json +from openqdc.utils.molecule import get_atom_data +from openqdc.utils.paths import get_local_cache +from openqdc.utils.constants import MAX_ATOMIC_NUMBER +from openqdc.datasets.base import BaseDataset + + +def read_mol(mol_id, mol_dict, base_path, partition): + """ Read molecule from pickle file and return dict with conformers and energies + + Parameters + ---------- + mol_id: str + Unique identifier for the molecule + mol_dict: dict + Dictionary containing the pickle_path and smiles of the molecule + base_path: str + Path to the folder containing the pickle files + + Returns + ------- + res: dict + Dictionary containing the following keys: + - atom_data_and_positions: flatten np.ndarray of shape (M, 4) containing the atomic numbers and positions + - smiles: np.ndarray of shape (N,) containing the smiles of the molecule + - energies: np.ndarray of shape (N,1) containing the energies of the conformers + - n_atoms: np.ndarray of shape (N,) containing the number of atoms in each conformer + """ + + try: + d = load_pkl(p_join(base_path, mol_dict['pickle_path']), False) + confs = d['conformers'] + x = get_atom_data(confs[0]['rd_mol']) + positions = np.array([cf['rd_mol'].GetConformer().GetPositions() for cf in confs]) + n_confs = positions.shape[0] + + res = dict( + atom_data_and_positions = np.concatenate(( + x[None, ...].repeat(n_confs, axis=0), + positions), axis=-1, dtype=np.float32).reshape(-1, 5), + smiles = np.array([d['smiles'] for _ in confs]), + energies = np.array([cf['totalenergy'] for cf in confs], dtype=np.float32)[:, None], + n_atoms = np.array([positions.shape[1]] * n_confs, dtype=np.int32), + subset = np.array([partition] * n_confs), + ) + + except Exception as e: + print (f'Skipping: {mol_id} due to {e}') + res = None + + return res + + +class NablaDFT(BaseDataset): + __name__ = 'nabladft' + __qm_methods__ = ["wb97x_svp"] + + energy_target_names = ["wb97x_svp.energy"] + force_target_names = [] + + # Energy in hartree, all zeros by default + atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32) + + partitions = ['qm9', 'drugs'] + + def __init__(self) -> None: + super().__init__() + + def read_raw_entries(self): + raw_path = p_join(self.root, 'nabladft') + + mols = load_json(p_join(raw_path, f'summary_{partition}.json')) + mols = list(mols.items()) + + fn = lambda x: read_mol(x[0], x[1], raw_path, partition) + samples = dm.parallelized(fn, mols, n_jobs=1, progress=True) # don't use more than 1 job + return samples + + +if __name__ == '__main__': + from openqdc.utils.paths import get_local_cache + from nablaDFT.dataset import HamiltonianDatabase + + f_path = p_join(get_local_cache(), "nabladft", "train_2k_energy.db") + f_path = p_join(get_local_cache(), "nabladft", "dataset_train_2k.db") + print(f_path) + train = HamiltonianDatabase(f_path) + Z, R, E, F, H, S, C = train[0] + print(Z.shape, R.shape, E.shape, F.shape, H.shape, S.shape, C.shape) + + + # + # data = NablaDFT() + # n = len(data) + + # for i in np.random.choice(n, 10, replace=False): + # x = data[i] + # print(x.smiles, x.subset, end=' ') + # for k in x: + # if k != 'smiles' and k != 'subset': + # print(k, x[k].shape if x[k] is not None else None, end=' ') + + # print() diff --git a/openqdc/datasets/pcqm.py b/openqdc/datasets/pcqm.py new file mode 100644 index 0000000..e69de29 diff --git a/openqdc/datasets/qmugs.py b/openqdc/datasets/qmugs.py new file mode 100644 index 0000000..324e8a1 --- /dev/null +++ b/openqdc/datasets/qmugs.py @@ -0,0 +1,94 @@ +import os +import pickle as pkl +import 
numpy as np +import pandas as pd +import os.path as osp +import datamol as dm +from tqdm import tqdm +from glob import glob +from sklearn.utils import Bunch +from rdkit import Chem +from os.path import join as p_join +from openqdc.utils import load_hdf5_file +from openqdc.utils.molecule import get_atom_data +from openqdc.utils.paths import get_local_cache +from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER +from openqdc.datasets.base import BaseDataset + + +def read_mol(mol_dir): + filenames = glob(p_join(mol_dir, "*.sdf")) + mols = [dm.read_sdf(f)[0] for f in filenames] + n_confs = len(mols) + + if len(mols) == 0: + return None + + smiles = dm.to_smiles(mols[0], explicit_hs=False) + subset = dm.to_smiles(dm.to_scaffold_murcko(mols[0], make_generic=True), explicit_hs=False) + x = get_atom_data(mols[0])[None, ...].repeat(n_confs, axis=0) + positions= np.array([mol.GetConformer().GetPositions() for mol in mols]) + props = [mol.GetPropsAsDict() for mol in mols] + targets = np.array([[p[el]for el in QMugs.energy_target_names] for p in props]) + + res = dict( + smiles= np.array([smiles]*n_confs), + subset= np.array([subset]*n_confs), + energies= targets.astype(np.float32), + atom_data_and_positions = np.concatenate((x, positions), + axis=-1, dtype=np.float32).reshape(-1, 5), + n_atoms = np.array([x.shape[1]]*n_confs, dtype=np.int32), + ) + + # for key in res: + # print(key, res[key].shape, res[key].dtype) + # exit() + + return res + + +class QMugs(BaseDataset): + + __name__ = 'qmugs' + __qm_methods__ = ["b3lyp/6-31g*"] + + energy_target_names = ["GFN2:TOTAL_ENERGY", "DFT:TOTAL_ENERGY",] + # target_names = [ + # "GFN2:TOTAL_ENERGY", + # "GFN2:ATOMIC_ENERGY", + # "GFN2:FORMATION_ENERGY", + # "DFT:TOTAL_ENERGY", + # "DFT:ATOMIC_ENERGY", + # "DFT:FORMATION_ENERGY", + # ] + + force_target_names = [] + + # Energy in hartree, all zeros by default + atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32) + + + def __init__(self) -> None: + super().__init__() + + def read_raw_entries(self): + raw_path = p_join(self.root, 'structures') + mol_dirs = [p_join(raw_path, d) for d in os.listdir(raw_path)] + + tmp = dm.parallelized(read_mol, mol_dirs, n_jobs=-1, + progress=True, scheduler="threads") + return tmp + + +if __name__ == '__main__': + data = QMugs() + n = len(data) + + for i in np.random.choice(n, 10, replace=False): + x = data[i] + print(x.smiles, x.subset, end=' ') + for k in x: + if k != 'smiles' and k != 'subset': + print(k, x[k].shape if x[k] is not None else None, end=' ') + + print() \ No newline at end of file diff --git a/openqdc/datasets/spice.py b/openqdc/datasets/spice.py new file mode 100644 index 0000000..34c5145 --- /dev/null +++ b/openqdc/datasets/spice.py @@ -0,0 +1,100 @@ +import numpy as np +from tqdm import tqdm +import datamol as dm +from os.path import join as p_join +from openqdc.utils import load_hdf5_file +from openqdc.utils.molecule import get_atom_data +from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER +from openqdc.datasets.base import BaseDataset + + +def read_record(r): + smiles = r["smiles"].asstr()[0] + subset = r["subset"][0].decode("utf-8") + n_confs = r["conformations"].shape[0] + x = get_atom_data(dm.to_mol(smiles, add_hs=True)) + positions= r["conformations"][:] * BOHR2ANG + + res = dict( + smiles= np.array([smiles]*n_confs), + subset= np.array([Spice.subset_mapping[subset]]*n_confs), + energies= r[Spice.energy_target_names[0]][:][:, None].astype(np.float32), + forces= r[Spice.force_target_names[0]][:].reshape(-1, 3, 1) / 
BOHR2ANG, + atom_data_and_positions = np.concatenate(( + x[None, ...].repeat(n_confs, axis=0), + positions), axis=-1, dtype=np.float32).reshape(-1, 5), + n_atoms = np.array([x.shape[0]]*n_confs, dtype=np.int32), + ) + + return res + + +class Spice(BaseDataset): + __name__ = 'spice' + __qm_methods__ = ["wb97x_tz"] + + energy_target_names = ["dft_total_energy"] + + force_target_names = ["dft_total_gradient"] + + # Energy in hartree, all zeros by default + atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32) + tmp = { + 35: -2574.2451510945853, + 6: -37.91424135791358, + 20: -676.9528465198214, + 17: -460.3350243496703, + 9: -99.91298732343974, + 1: -0.5027370838721259, + 53: -297.8813829975981, + 19: -599.8025677513111, + 3: -7.285254714046546, + 12: -199.2688420040449, + 7: -54.62327513368922, + 11: -162.11366478783253, + 8: -75.17101657391741, + 15: -341.3059197024934, + 16: -398.2405387031612, + } + for key in tmp: + atomic_energies[key] = tmp[key] + + subset_mapping = { + "SPICE Solvated Amino Acids Single Points Dataset v1.1": "Solvated Amino Acids", + "SPICE Dipeptides Single Points Dataset v1.2": "Dipeptides", + "SPICE DES Monomers Single Points Dataset v1.1": "DES370K Monomers", + "SPICE DES370K Single Points Dataset v1.0": "DES370K Dimers", + "SPICE DES370K Single Points Dataset Supplement v1.0": "DES370K Dimers", + "SPICE PubChem Set 1 Single Points Dataset v1.2": "PubChem", + "SPICE PubChem Set 2 Single Points Dataset v1.2": "PubChem", + "SPICE PubChem Set 3 Single Points Dataset v1.2": "PubChem", + "SPICE PubChem Set 4 Single Points Dataset v1.2": "PubChem", + "SPICE PubChem Set 5 Single Points Dataset v1.2": "PubChem", + "SPICE PubChem Set 6 Single Points Dataset v1.2": "PubChem", + "SPICE Ion Pairs Single Points Dataset v1.1": "Ion Pairs", + } + + def __init__(self) -> None: + super().__init__() + + def read_raw_entries(self): + raw_path = p_join(self.root, 'SPICE-1.1.4.hdf5') + + data = load_hdf5_file(raw_path) + tmp = [read_record(data[mol_name]) for mol_name in tqdm(data)] # don't use parallelized here + + return tmp + + +if __name__ == '__main__': + data = Spice() + n = len(data) + + for i in np.random.choice(n, 10, replace=False): + x = data[i] + print(x.smiles, x.subset, end=' ') + for k in x: + if k != 'smiles' and k != 'subset': + print(k, x[k].shape if x[k] is not None else None, end=' ') + + print() \ No newline at end of file diff --git a/openqdc/raws/config_factory.py b/openqdc/raws/config_factory.py new file mode 100644 index 0000000..895aa79 --- /dev/null +++ b/openqdc/raws/config_factory.py @@ -0,0 +1,125 @@ + +class DataConfigFactory: + + ani = dict( + dataset_name="ani", + links={ + "ani1.h5": "https://springernature.figshare.com/ndownloader/files/18112775", + "an1x.hdf5.gz": "https://zenodo.org/record/4081694/files/292.hdf5.gz", + "ani1ccx.hdf5.gz": "https://zenodo.org/record/4081692/files/293.hdf5.gz", + }, + ) + + comp6 = dict( + dataset_name="comp6", + links={ + "gdb7_9.hdf5.gz": "https://zenodo.org/record/3588361/files/208.hdf5.gz", + "gdb10_13.hdf5.gz": "https://zenodo.org/record/3588364/files/209.hdf5.gz", + "drugbank.hdf5.gz": "https://zenodo.org/record/3588361/files/207.hdf5.gz", + "tripeptides.hdf5.gz": "https://zenodo.org/record/3588368/files/211.hdf5.gz", + "ani_md.hdf5.gz": "https://zenodo.org/record/3588341/files/205.hdf5.gz", + "s66x8.hdf5.gz": "https://zenodo.org/record/3588367/files/210.hdf5.gz", + }, + ) + + gdml = dict( + dataset_name="gdml", + links = { + "gdml.hdf5.gz": "https://zenodo.org/record/3585908/files/219.hdf5.gz" + 
}, + ) + + solvated_peptides = dict( + dataset_name="solvated_peptides", + links = { + "solvated_peptides.hdf5.gz": "https://zenodo.org/record/3585804/files/213.hdf5.gz" + }, + ) + + iso_17 = dict( + dataset_name="iso_17", + links = { + "iso_17.hdf5.gz": "https://zenodo.org/record/3585907/files/216.hdf5.gz" + }, + ) + + sn2_rxn = dict( + dataset_name="sn2_rxn", + links = { + "sn2_rxn.hdf5.gz": "https://zenodo.org/record/3585800/files/212.hdf5.gz" + }, + ) + + # FROM: https://sites.uw.edu/wdbase/database-of-water-clusters/ + waterclusters3_30 = dict( + dataset_name="waterclusters3_30", + links = { + "W3-W30_all_geoms_TTM2.1-F.zip": "https://drive.google.com/file/d/18Y7OiZXSCTsHrQ83GCc4fyE_abbL6E_n" + }, + ) + + geom = dict( + dataset_name="geom", + links = { + "rdkit_folder.tar.gz": "https://dataverse.harvard.edu/api/access/datafile/4327252" + }, + ) + + molecule3d = dict( + dataset_name="molecule3d", + links={ + "molecule3d.zip": "https://drive.google.com/uc?id=1C_KRf8mX-gxny7kL9ACNCEV4ceu_fUGy" + }, + ) + + nabladft = dict( + dataset_name="nabladft", + links={ + "nabladft.db": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_full.db" + }, + ) + + orbnet_denali = dict( + dataset_name="orbnet_denali", + links={ + "orbnet_denali.tar.gz": "https://figshare.com/ndownloader/files/28672287", + "orbnet_denali_targets.tar.gz": "https://figshare.com/ndownloader/files/28672248"}, + ) + + qm7x = dict( + dataset_name="qm7x", + links={f"{i}000.xz": f"https://zenodo.org/record/4288677/files/{i}000.xz" for i in range(1, 9) + } + ) + + qmugs = dict( + dataset_name="qmugs", + links={ + "summary.csv": "https://libdrive.ethz.ch/index.php/s/X5vOBNSITAG5vzM/download?path=%2F&files=summary.csv", + "structures.tar.gz": "https://libdrive.ethz.ch/index.php/s/X5vOBNSITAG5vzM/download?path=%2F&files=structures.tar.gz", + }, + ) + + spice = dict( + dataset_name="spice", + links={ + "SPICE-1.1.4.hdf5": "https://zenodo.org/record/8222043/files/SPICE-1.1.4.hdf5" + }, + ) + + pubchemqc = dict( + dataset_name="pubchemqc", + links={ + "pubchemqc.tar.gz": "https://zenodo.org/record/8222043/files/pubchemqc.tar.gz" + }, + ) + + available_datasets = [k for k in locals().keys() if not k.startswith("__")] + + def __init__(self): + pass + + def __call__(self, dataset_name): + return getattr(self, dataset_name) + + diff --git a/openqdc/raws/fetch.py b/openqdc/raws/fetch.py new file mode 100644 index 0000000..694a680 --- /dev/null +++ b/openqdc/raws/fetch.py @@ -0,0 +1,124 @@ +"""Script to download the raw datasets defined in DataConfigFactory.""" +import os +import tqdm +import gdown +import fsspec +import socket +import tarfile +import zipfile +import requests +import urllib.error +import urllib.request +from loguru import logger +from sklearn.utils import Bunch +from openqdc.utils.paths import get_local_cache +from openqdc.raws.config_factory import DataConfigFactory + + +# function to download large files with requests +def fetch_file(url, local_filename, overwrite=False): + """ + Download a file from a url to a local file. + Parameters + ---------- + url : str + URL to download from. + local_filename : str + Local file to save to. + overwrite : bool + Whether to overwrite existing files. + Returns + ------- + local_filename : str + Local file.
+ """ + try: + + if os.path.exists(local_filename) and not overwrite: + logger.info("File already exists, skipping download") + else: + logger.info(f"File: {local_filename}") + if "drive.google.com" in url: + gdown.download(url, local_filename, quiet=False) + else: + r = requests.get(url, stream=True) + with fsspec.open(local_filename, "wb") as f: + for chunk in tqdm.tqdm(r.iter_content(chunk_size=16384)): + if chunk: + f.write(chunk) + + # decompress archive if necessary + parent = os.path.dirname(local_filename) + if local_filename.endswith("tar.gz"): + with tarfile.open(local_filename) as tar: + logger.info(f"Verifying archive extraction states: {local_filename}") + all_names = tar.getnames() + all_extracted = all([os.path.exists(os.path.join(parent, x)) for x in all_names]) + if not all_extracted: + logger.info(f"Extracting archive: {local_filename}") + tar.extractall(path=parent) + else: + logger.info(f"Archive already extracted: {local_filename}") + + elif local_filename.endswith("zip"): + logger.info(f"Verifying archive extraction states: {local_filename}") + with zipfile.ZipFile(local_filename, "r") as zip_ref: + all_names = zip_ref.namelist() + all_extracted = all([os.path.exists(os.path.join(parent, x)) for x in all_names]) + if not all_extracted: + logger.info(f"Extracting archive: {local_filename}") + zip_ref.extractall(parent) + else: + logger.info(f"Archive already extracted: {local_filename}") + + elif local_filename.endswith("xz"): + logger.info(f"Extracting archive: {local_filename}") + + os.system(f"cd {parent} && xz -d *.xz") + else: + pass + + except (socket.gaierror, urllib.error.URLError) as err: + raise ConnectionError("Could not download {} due to {}".format(url, err)) + + return local_filename + + +class DataDownloader: + """Download data from a remote source. + Parameters + ---------- + cache_path : str + Path to the cache directory. + overwrite : bool + Whether to overwrite existing files.
+ """ + + def __init__(self, cache_path=None, overwrite=False): + if cache_path is None: + cache_path = get_local_cache() + + self.cache_path = cache_path + self.overwrite = overwrite + + def from_config(self, config: dict): + b_config = Bunch(**config) + data_path = os.path.join(self.cache_path, b_config.dataset_name) + os.makedirs(data_path, exist_ok=True) + + logger.info(f"Downloading the {b_config.dataset_name} dataset") + for local, link in b_config.links.items(): + outfile = os.path.join(data_path, local) + + fetch_file(link, outfile) + + def from_name(self, name): + cfg = DataConfigFactory()(name) + return self.from_config(cfg) + + +if __name__ == "__main__": + for dataset_name in DataConfigFactory.available_datasets: + dd = DataDownloader() + dd.from_name(dataset_name) + diff --git a/openqdc/utils/__init__.py b/openqdc/utils/__init__.py new file mode 100644 index 0000000..92eec25 --- /dev/null +++ b/openqdc/utils/__init__.py @@ -0,0 +1,21 @@ +from .io import ( + check_file, + create_hdf5_file, + load_hdf5_file, + load_json, + load_pkl, + load_torch, + makedirs, + save_pkl, +) + +__all__ = [ + "load_pkl", + "save_pkl", + "makedirs", + "load_hdf5_file", + "load_json", + "load_torch", + "create_hdf5_file", + "check_file", +] diff --git a/openqdc/utils/constants.py b/openqdc/utils/constants.py new file mode 100644 index 0000000..929db32 --- /dev/null +++ b/openqdc/utils/constants.py @@ -0,0 +1,5 @@ +MAX_ATOMIC_NUMBER = 119 + +HAR2EV = 27.211386246 + +BOHR2ANG = 0.52917721092 \ No newline at end of file diff --git a/openqdc/utils/io.py b/openqdc/utils/io.py new file mode 100644 index 0000000..a29d771 --- /dev/null +++ b/openqdc/utils/io.py @@ -0,0 +1,109 @@ +"""IO utilities for mlip package""" +import json +import os +import pickle + +import fsspec +import h5py +import torch +from gcsfs import GCSFileSystem + + +def load_torch_gcs(path): + """Loads torch file""" + # get file system + fs: GCSFileSystem = fsspec.filesystem("gs") + + # load from GCS + with fs.open(path, "rb") as fp: + return torch.load(fp) + + +def load_torch(path): + """Loads torch file""" + return torch.load(path) + + +def makedirs_gcs(path, exist_ok=True): + """Creates directory""" + fs: GCSFileSystem = fsspec.filesystem("gs") + fs.mkdirs(path, exist_ok=exist_ok) + + +def makedirs(path, exist_ok=True): + os.makedirs(path, exist_ok=exist_ok) + + +def check_file(path) -> bool: + """Checks if file present on local""" + return os.path.exists(path) + + +def check_file_gcs(path) -> bool: + """Checks if file present on GCS FileSystem""" + # get file system + fs: GCSFileSystem = fsspec.filesystem("gs") + return fs.exists(path) + + +def save_pkl(file, path): + """Saves pickle file""" + print(f"Saving file at {path}") + with fsspec.open(path, "wb") as fp: # Pickling + pickle.dump(file, fp) + print("Done") + + +def load_pkl_gcs(path, check=True): + """Load pickle file from GCS FileSystem""" + if check: + if not check_file_gcs(path): + raise FileNotFoundError(f"File {path} does not exist on GCS and local.") + + # get file system + fs: GCSFileSystem = fsspec.filesystem("gs") + + with fs.open(path, "rb") as fp: # Unpickling + return pickle.load(fp) + + +def load_pkl(path, check=True): + """Load pickle file""" + if check: + if not check_file(path): + raise FileNotFoundError(f"File {path} does not exist on GCS and local.") + + with open(path, "rb") as fp: # Unpickling + return pickle.load(fp) + + +def load_hdf5_file(hdf5_file_path: str): + """Loads hdf5 file with fsspec""" + if not check_file(hdf5_file_path): + raise 
FileNotFoundError(f"File {hdf5_file_path} does not exist on GCS and local.") + + fp = fsspec.open(hdf5_file_path, "rb") + if hasattr(fp, "open"): + fp = fp.open() + file = h5py.File(fp) + + # inorder to enable multiprocessing: + # https://github.com/fsspec/gcsfs/issues/379#issuecomment-839929801 + fsspec.asyn.iothread[0] = None + fsspec.asyn.loop[0] = None + + return file + + +def create_hdf5_file(hdf5_file_path: str): + """Creates hdf5 file with fsspec""" + fp = fsspec.open(hdf5_file_path, "wb") + if hasattr(fp, "open"): + fp = fp.open() + return h5py.File(fp, "a") + + +def load_json(path): + """Loads json file""" + with fsspec.open(path, "r") as fp: # Unpickling + return json.load(fp) diff --git a/openqdc/utils/molecule.py b/openqdc/utils/molecule.py new file mode 100644 index 0000000..4890b17 --- /dev/null +++ b/openqdc/utils/molecule.py @@ -0,0 +1,19 @@ +import numpy as np +from rdkit import Chem + + +def get_atomic_number(mol: Chem.Mol): + """Returns atomic numbers for rdkit molecule""" + return np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()]) + + +def get_atomic_charge(mol: Chem.Mol): + """Returns atom charges for rdkit molecule""" + return np.array([atom.GetFormalCharge() for atom in mol.GetAtoms()]) + + +def get_atom_data(mol: Chem.Mol): + """Returns atoms number and charge for rdkit molecule""" + return np.array([[atom.GetAtomicNum(), atom.GetFormalCharge()] + for atom in mol.GetAtoms()]) + diff --git a/openqdc/utils/paths.py b/openqdc/utils/paths.py new file mode 100644 index 0000000..e38ecb4 --- /dev/null +++ b/openqdc/utils/paths.py @@ -0,0 +1,16 @@ +import os + +def get_local_cache(): + fname = os.path.abspath(__file__) + base ='/'.join(fname.split('/')[:-3]) + cache_dir=os.path.join(base, 'cache') + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +def get_remote_cache(): + fname = os.path.abspath(__file__) + base ='/'.join(fname.split('/')[:-3]) + cache_dir=os.path.join(base, 'cache') + os.makedirs(cache_dir, exist_ok=True) + return cache_dir \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..be78fcf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,60 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "openqdc" +dynamic = ["version"] +description = "ML ready Quantum Mechanical datasets" +authors = [{ name = "Nikhil Shenoy", email = "nikhilshenoy98@gmail.com" }, + { name = "Prudencio Tossou", email = "tossouprudencio@gmail.com" }] + +[tool.setuptools] +include-package-data = true + +[tool.setuptools_scm] +fallback_version = "dev" + +[tool.isort] +profile = "black" + +[tool.setuptools.packages.find] +where = ["."] +include = ["openqdc", "openqdc.*"] +exclude = [] +namespaces = true + +[tool.pylint.messages_control] +disable = [ + "no-member", + "too-many-arguments", + "too-few-public-methods", + "no-else-return", + "duplicate-code", + "too-many-branches", + "redefined-builtin", + "dangerous-default-value", +] + +[tool.pylint.format] +max-line-length = 120 + +[tool.black] +line-length = 120 +target-version = ['py39', 'py310'] +include = '\.pyi?$' + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-vv" +testpaths = ["tests"] +filterwarnings = [] + +[tool.coverage.run] +omit = ["setup.py", "tests/*"] + +[tool.ruff] +line-length = 120 + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F403"]
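All of the dataset classes added in this diff share the `BaseDataset` machinery: the first instantiation calls `read_raw_entries()` and preprocesses the result into memory-mapped arrays under the local cache, after which indexing returns an `sklearn.utils.Bunch` per conformer. A minimal usage sketch, mirroring the `__main__` blocks in the dataset modules and assuming the corresponding raw file (here `SPICE-1.1.4.hdf5`) has already been fetched into the cache:

```python
import numpy as np

from openqdc.datasets.spice import Spice

# First call triggers read_raw_entries() + preprocessing; subsequent calls
# simply memory-map the preprocessed arrays from the local cache.
dataset = Spice()
print(len(dataset))  # total number of conformers

idx = np.random.randint(len(dataset))
sample = dataset[idx]  # sklearn.utils.Bunch

print(sample.smiles, sample.subset)
print(sample.atomic_numbers.shape)  # (n_atoms,)
print(sample.positions.shape)       # xyz coordinates, one row per atom
print(sample.energies.shape)        # one value per energy target
if sample.forces is not None:       # only for datasets with force targets
    print(sample.forces.shape)
```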