Docstrings+type hinting of various methods

FNTwin committed Jul 10, 2024
1 parent dba3607 commit e44281f

Showing 13 changed files with 224 additions and 76 deletions.
80 changes: 53 additions & 27 deletions openqdc/cli.py
@@ -19,22 +19,34 @@


 def sanitize(dictionary):
+    """
+    Sanitize dataset names to be used in the CLI.
+    """
     return {k.lower().replace("_", "").replace("-", ""): v for k, v in dictionary.items()}


 SANITIZED_AVAILABLE_DATASETS = sanitize(AVAILABLE_DATASETS)


-def exist_dataset(dataset):
+def exist_dataset(dataset) -> bool:
+    """
+    Check if a dataset is available in the openQDC datasets.
+    """
     if dataset not in sanitize(AVAILABLE_DATASETS):
         logger.error(f"{dataset} is not available. Please open an issue on Github for the team to look into it.")
         return False
     return True


-def format_entry(empty_dataset):
+def format_entry(empty_dataset, max_num_to_display: int = 6):
     """
     Format the entry for the table.
+    max_num_to_display: int = 6,
+        Maximum number of energy methods to display. Used to keep the table format
+        readable in case of datasets with many energy methods. [e.g. MultiXQM9]
     """
     energy_methods = [str(x) for x in empty_dataset.__energy_methods__]
-    max_num_to_display = 6

     if len(energy_methods) > 6:
         entry = ",".join(energy_methods[:max_num_to_display]) + "..."
     else:
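As a standalone illustration (not part of the commit), this is how the truncation controlled by the new max_num_to_display parameter behaves; the method names below are made up:

# Sketch: format_entry keeps the table readable for datasets with many energy methods.
energy_methods = [f"wb97x/6-31g(d)_{i}" for i in range(10)]
max_num_to_display = 6

if len(energy_methods) > max_num_to_display:
    entry = ",".join(energy_methods[:max_num_to_display]) + "..."  # truncated with an ellipsis
else:
    entry = ",".join(energy_methods)  # short lists are shown in full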
@@ -48,7 +60,7 @@ def download(
     overwrite: Annotated[
         bool,
         typer.Option(
-            help="Whether to overwrite or force the re-download of the datasets.",
+            help="Whether to force the re-download of the datasets and overwrite the current cached dataset.",
         ),
     ] = False,
     cache_dir: Annotated[
@@ -60,13 +72,14 @@
     as_zarr: Annotated[
         bool,
         typer.Option(
-            help="Whether to overwrite or force the re-download of the datasets.",
+            help="Whether to use a zarr format for the datasets instead of memmap.",
         ),
     ] = False,
     gs: Annotated[
         bool,
         typer.Option(
-            help="Whether to use gs to re-download of the datasets.",
+            help="Which source to use for downloading. If True, Google Storage will be used."
+            + " Otherwise, AWS S3 will be used.",
         ),
     ] = False,
 ):
@@ -78,6 +91,7 @@
     """
     if gs:
         os.environ["OPENQDC_DOWNLOAD_API"] = "gs"
+
     for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
         if exist_dataset(dataset):
             ds = SANITIZED_AVAILABLE_DATASETS[dataset].no_init()
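For context, a hedged sketch of the source selection the gs flag drives; the "s3" default shown here is an assumption, not something this diff confirms:

import os

def resolve_download_api(gs: bool) -> str:
    # Mirrors the branch above: Google Storage when gs is True,
    # otherwise the default remote (assumed here to be AWS S3).
    if gs:
        os.environ["OPENQDC_DOWNLOAD_API"] = "gs"
    return os.environ.get("OPENQDC_DOWNLOAD_API", "s3")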
@@ -93,7 +107,7 @@
 @app.command()
 def datasets():
     """
-    Print a table of the available openQDC datasets and some informations.
+    Print a formatted table of the available openQDC datasets and some information.
     """
     table = PrettyTable(["Name", "Type of Energy", "Forces", "Level of theory"])
     for dataset in AVAILABLE_DATASETS:
@@ -118,7 +132,7 @@ def fetch(
     overwrite: Annotated[
         bool,
         typer.Option(
-            help="Whether to overwrite or force the re-download of the files.",
+            help="Whether to overwrite or force the re-download of the raw files.",
         ),
     ] = False,
     cache_dir: Annotated[
@@ -129,17 +143,14 @@
     ] = None,
 ):
     """
-    Download the raw datasets files from the main openQDC hub.
-    overwrite: bool = False,
-        If True, the files will be re-downloaded and overwritten.
-    cache_dir: Optional[str] = None,
-        Path to the cache. If not provided, the default cache directory will be used.
-    Special case: if the dataset is "all", "potential", "interaction".
-    all: all available datasets will be downloaded.
-    potential: all the potential datasets will be downloaded
-    interaction: all the interaction datasets will be downloaded
-    Example:
-        openqdc fetch Spice
+    Download the raw dataset files from the main openQDC hub.\n
+    Special case: if the dataset is "all", "potential", "interaction".\n
+    all: all available datasets will be downloaded.\n
+    potential: all the potential datasets will be downloaded.\n
+    interaction: all the interaction datasets will be downloaded.\n\n
+    Example:\n
+        openqdc fetch Spice
     """
     if datasets[0].lower() == "all":
         dataset_names = list(sanitize(AVAILABLE_DATASETS).keys())
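The "potential" and "interaction" branches are collapsed in this hunk. Below is a hedged sketch of the full dispatch the docstring describes; the AVAILABLE_POTENTIAL_DATASETS and AVAILABLE_INTERACTION_DATASETS names and the import path are assumptions:

from openqdc import AVAILABLE_DATASETS  # assumed import path

datasets = ["potential"]  # what the user typed on the command line
keyword = datasets[0].lower()
if keyword == "all":
    dataset_names = list(sanitize(AVAILABLE_DATASETS).keys())
elif keyword == "potential":
    dataset_names = list(sanitize(AVAILABLE_POTENTIAL_DATASETS).keys())  # assumed name
elif keyword == "interaction":
    dataset_names = list(sanitize(AVAILABLE_INTERACTION_DATASETS).keys())  # assumed name
else:
    dataset_names = datasets  # fetch exactly the named datasets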
@@ -163,18 +174,27 @@ def preprocess(
     overwrite: Annotated[
         bool,
         typer.Option(
-            help="Whether to overwrite or force the re-download of the datasets.",
+            help="Whether to overwrite the current cached datasets.",
         ),
     ] = True,
     upload: Annotated[
         bool,
         typer.Option(
-            help="Whether to try the upload to the remote storage.",
+            help="Whether to attempt the upload to the remote storage. Must have write permissions.",
         ),
     ] = False,
+    as_zarr: Annotated[
+        bool,
+        typer.Option(
+            help="Whether to preprocess into a zarr format or a memmap format.",
+        ),
+    ] = False,
 ):
     """
+    Preprocess a raw dataset (previously fetched) into an openqdc dataset and optionally push it to remote.
+    Example:
+        openqdc preprocess Spice QMugs
     """
     for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
         if exist_dataset(dataset):
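The loop body is collapsed here. A hedged sketch of the per-dataset call it presumably makes; the preprocess signature is an assumption inferred from the CLI options above:

ds = SANITIZED_AVAILABLE_DATASETS[dataset].no_init()
ds.preprocess(upload=upload, overwrite=overwrite, as_zarr=as_zarr)  # assumed signature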
@@ -192,7 +212,7 @@ def upload(
     overwrite: Annotated[
         bool,
         typer.Option(
-            help="Whether to overwrite or force the re-download of the datasets.",
+            help="Whether to overwrite the remote files if they are present.",
         ),
     ] = True,
     as_zarr: Annotated[
@@ -204,6 +224,9 @@
 ):
     """
     Upload a preprocessed dataset to the remote storage.
+    Example:
+        openqdc upload Spice --overwrite
     """
     for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
         if exist_dataset(dataset):


 @app.command()
-def convert_to_zarr(
+def convert(
     datasets: List[str],
     overwrite: Annotated[
         bool,
         typer.Option(
-            help="Whether to overwrite or force the re-download of the datasets.",
+            help="Whether to overwrite the current zarr cached datasets.",
         ),
     ] = False,
     download: Annotated[
         bool,
         typer.Option(
-            help="Whether to force the re-download of the datasets.",
+            help="Whether to force the re-download of the memmap datasets.",
         ),
     ] = False,
 ):
     """
-    Convert a preprocessed dataset to the zarr file format.
+    Convert a preprocessed dataset from a memmap dataset to a zarr dataset.
     """
     import os
     from os.path import join as p_join
@@ -243,6 +266,9 @@ def convert_to_zarr(
     from openqdc.utils.io import load_pkl

     def silent_remove(filename):
+        """
+        Zarr zip files are currently not overwritable. This function is used to remove the file if it exists.
+        """
         try:
             os.remove(filename)
         except OSError:
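To illustrate why silent_remove is needed: zip-backed zarr stores cannot be overwritten in place, so any stale archive has to be removed before a fresh one is written. A minimal sketch against the zarr v2 API:

import numpy as np
import zarr

silent_remove("dataset.zip")  # drop any stale archive first
store = zarr.storage.ZipStore("dataset.zip", mode="w")
root = zarr.group(store=store)
root.array("energies", data=np.arange(3.0))  # write a fresh array
store.close()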
@@ -305,7 +331,7 @@ def silent_remove(filename):


 @app.command()
-def show_cache():
+def cache():
     """
     Get the current local cache path of openQDC
     """
2 changes: 1 addition & 1 deletion openqdc/datasets/base.py
@@ -361,7 +361,7 @@ def set_array_format(self, format: str):
     def read_raw_entries(self):
         raise NotImplementedError

-    def collate_list(self, list_entries):
+    def collate_list(self, list_entries: List[Dict]):
         # concatenate entries
         res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) for key in list_entries[0]}
7 changes: 5 additions & 2 deletions openqdc/datasets/dataset_structure.py
@@ -1,7 +1,7 @@
 import pickle as pkl
 from abc import ABC, abstractmethod
 from os.path import join as p_join
-from typing import List, Optional
+from typing import Callable, List, Optional

 import numpy as np
 import zarr
@@ -23,7 +23,10 @@ def ext(self):

     @property
     @abstractmethod
-    def load_fn(self):
+    def load_fn(self) -> Callable:
+        """
+        Function to use for loading the data.
+        """
         raise NotImplementedError

     def add_extension(self, filename):
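A hedged sketch of a concrete subclass implementing the abstract property; the GeneralStructure base-class name and the pickle choice are assumptions for illustration:

import pickle as pkl
from typing import Callable

class PickleStructure(GeneralStructure):  # base-class name is an assumption
    @property
    def ext(self):
        return ".pkl"

    @property
    def load_fn(self) -> Callable:
        # the property returns the callable used to deserialize one file
        return pkl.load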
61 changes: 43 additions & 18 deletions openqdc/datasets/io.py
@@ -17,6 +17,8 @@ def try_retrieve(obj, callable, default=None):


 class FromFileDataset(BaseDataset, ABC):
+    """Abstract class for datasets that read from a common file format like xyz, netcdf, gro, hdf5, etc."""
+
     def __init__(
         self,
         path: List[str],
@@ -35,12 +37,30 @@ def __init__(
         },
     ):
         """
-        Create a dataset from a xyz file.
+        Create a dataset from a list of files.

+        Parameters
+        ----------
+        path : List[str]
+            The path to the file or a list of paths.
+        dataset_name : Optional[str], optional
+            The name of the dataset, by default None.
+        energy_type : Optional[str], optional
+            The type of isolated atom energy, by default "regression".
+            Supported types: ["formation", "regression", "null", None]
+        energy_unit
+            Energy unit of the dataset. Default is "hartree".
+        distance_unit
+            Distance unit of the dataset. Default is "ang".
+        level_of_theory : Optional[QmMethod, str]
+            The level of theory of the dataset.
+            Used if energy_type is "formation" to fetch the correct isolated atom energies.
+        transform, optional
+            Transformation to apply to the __getitem__ calls.
+        regressor_kwargs
+            Dictionary of keyword arguments to pass to the regressor.
+            Default: {"solver_type": "linear", "sub_sample": None, "stride": 1}
+            solver_type can be one of ["linear", "ridge"]
         """
         self.path = [path] if isinstance(path, str) else path
         self.__name__ = self.__class__.__name__ if dataset_name is None else dataset_name
@@ -62,29 +82,19 @@ def __init__(
         self.set_array_format(array_format)
         self._post_init(True, energy_unit, distance_unit)

-    def __str__(self):
-        return self.__name__.lower()
-
-    def __repr__(self):
-        return str(self)
-
     @abstractmethod
     def read_as_atoms(self, path: str) -> List[Atoms]:
         """
-        Method that reads a path and return a list of Atoms objects.
+        Method that reads a file and returns a list of Atoms objects.
+        path : str
+            The path to the file.
         """
         raise NotImplementedError

-    def collate_list(self, list_entries):
-        res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) for key in list_entries[0]}
-        csum = np.cumsum(res.get("n_atoms"))
-        x = np.zeros((csum.shape[0], 2), dtype=np.int32)
-        x[1:, 0], x[:, 1] = csum[:-1], csum
-        res["position_idx_range"] = x
-
-        return res
-
-    def read_raw_entries(self):
+    def read_raw_entries(self) -> List[dict]:
+        """
+        Process the files and return a list of data objects.
+        """
         entries_list = []
         for path in self.path:
             for entry in self.read_as_atoms(path):
@@ -96,6 +106,11 @@ def _read_and_preprocess(self):
         self.data = self.collate_list(entries_list)

     def _convert_to_record(self, obj: Atoms):
+        """
+        Convert an Atoms object to a record for the openQDC dataset processing.
+        obj : Atoms
+            The ase.Atoms object to convert.
+        """
         name = obj.info.get("name", None)
         subset = obj.info.get("subset", str(self))
         positions = obj.positions
@@ -116,8 +131,18 @@ def _convert_to_record(self, obj: Atoms):
             n_atoms=np.array([len(positions)], dtype=np.int32),
         )
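For illustration, a small sketch of the Atoms-to-record conversion above, limited to the fields visible in this hunk:

import numpy as np
from ase import Atoms

obj = Atoms("H2O", positions=np.zeros((3, 3)))
obj.info["name"] = "water"   # read back via obj.info.get("name", None)
obj.info["subset"] = "demo"  # _convert_to_record falls back to str(self) if absent

n_atoms = np.array([len(obj.positions)], dtype=np.int32)  # array([3], dtype=int32)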

+    def __str__(self):
+        return self.__name__.lower()
+
+    def __repr__(self):
+        return str(self)


 class XYZDataset(FromFileDataset):
+    """
+    Base class to read datasets from xyz and extxyz files.
+    """
+
     def read_as_atoms(self, path):
         from ase.io import iread
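A hedged usage sketch of the new file-based entry point; the argument names follow the __init__ docstring above, and len(ds) assumes the usual BaseDataset interface:

from openqdc.datasets.io import XYZDataset

ds = XYZDataset(
    path=["molecules.xyz"],       # one or more xyz/extxyz files
    dataset_name="my_xyz_dataset",
    energy_type="regression",     # or "formation", "null", None
    energy_unit="hartree",
    distance_unit="ang",
)
print(len(ds))  # number of conformations collated from the files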

5 changes: 5 additions & 0 deletions openqdc/datasets/potential/alchemy.py
@@ -44,6 +44,11 @@ def read_mol(file, energy):


 class Alchemy(BaseDataset):
+    """
+    https://alchemy.tencent.com/
+    https://arxiv.org/abs/1906.09427
+    """
+
     __name__ = "alchemy"

     __energy_methods__ = [
6 changes: 5 additions & 1 deletion openqdc/datasets/potential/proteinfragments.py
@@ -90,7 +90,7 @@ def _unpack_data_tuple(self, data):

 # graphs is smiles
 class ProteinFragments(BaseDataset):
-    """ """
+    """https://www.science.org/doi/10.1126/sciadv.adn4397"""

     __name__ = "proteinfragments"
@@ -134,6 +134,10 @@ def read_raw_entries(self):


 class MDDataset(ProteinFragments):
+    """
+    Part of the proteinfragments dataset that is generated from molecular dynamics with their model.
+    """
+
     __name__ = "mddataset"

     __links__ = {