diff --git a/.github/workflows/check_pypi_build.yml b/.github/workflows/check_pypi_build.yml deleted file mode 100644 index a761d8c2..00000000 --- a/.github/workflows/check_pypi_build.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: Check PyPI Build -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - - workflow_dispatch: - -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Setup python - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: x64 - - name: Install Dependencies - run: | - python -m pip install -e . - python -m pip install build - - - name: Check Errors - run: | - python -m build \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..a9b022fa --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,164 @@ +name: Continuous Integration +on: + schedule: + - cron: "0 8 * * 1-5" + push: + branches: [main] + pull_request: + branches: [main] + workflow_dispatch: + +concurrency: + group: actions-id-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + check-formatting: + name: Check Build and Formatting Errors + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Dependencies + run: | + python -m pip install pycodestyle isort + + - name: Check Build + run: | + python -m pip install . + + - name: Run pycodestyle + run: | + pycodestyle --statistics --count --max-line-length=150 --show-source --ignore=E203 . + + - name: Check Import Ordering Errors + run: | + isort --check-only --verbose . + + build-and-test: + needs: check-formatting + continue-on-error: true + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + os: [ubuntu-latest, windows-latest, macos-latest] + + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + name: ${{ matrix.os }} Python ${{ matrix.python-version }} Subtest + steps: + - uses: actions/checkout@v3 + - uses: mamba-org/setup-micromamba@main + with: + environment-name: temp + condarc: | + channels: + - defaults + - conda-forge + channel_priority: flexible + create-args: | + python=${{ matrix.python-version }} + - name: Install Dependencies + run: | + python -m pip install -e .[molecules] + python -m pip install coverage pytest + - name: Run Tests + run: | + coverage run --source=. --omit=astartes/__init__.py,setup.py,test/* -m pytest -v + - name: Show Coverage + run: | + coverage report -m + + ipynb-ci: + needs: check-formatting + strategy: + fail-fast: false + matrix: + nb-file: + ["barrier_prediction_with_RDB7/RDB7_barrier_prediction_example", "train_val_test_split_sklearn_example/train_val_test_split_example", "split_comparisons/split_comparisons", "mlpds_2023_astartes_demonstration/mlpds_2023_demo"] + runs-on: ubuntu-latest + defaults: + run: + shell: bash -el {0} + name: Check ${{ matrix.nb-file }} Notebook Execution + steps: + - uses: actions/checkout@v3 + - uses: mamba-org/setup-micromamba@main + with: + environment-name: temp + condarc: | + channels: + - defaults + - conda-forge + channel_priority: flexible + create-args: | + python=3.11 + - name: Install dependencies + run: | + python -m pip install -e .[molecules,demos] + python -m pip install notebook + - name: Test Execution + run: | + cd examples/$(dirname ${{ matrix.nb-file }}) + jupyter nbconvert --to script $(basename ${{ matrix.nb-file }}).ipynb + ipython $(basename ${{ matrix.nb-file }}).py + + coverage-check: + if: contains(github.event.pull_request.labels.*.name, 'PR Ready for Review') + needs: [build-and-test, ipynb-ci] + runs-on: ubuntu-latest + defaults: + run: + shell: bash -el {0} + steps: + - uses: actions/checkout@v3 + - uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + python-version: "3.10" + - name: Install Dependencies + run: | + python -m pip install -e .[molecules] + python -m pip install coverage + - name: Run Tests + run: | + coverage run --source=. --omit=astartes/__init__.py,setup.py,test/*,astartes/samplers/sampler.py -m unittest discover -v + - name: Show Coverage + run: | + coverage report -m > temp.txt + cat temp.txt + python .github/workflows/coverage_helper.py + echo "COVERAGE_PERCENT=$(cat temp2.txt)" >> $GITHUB_ENV + + - name: Request Changes via Review + if: ${{ env.COVERAGE_PERCENT < 90 }} + uses: andrewmusgrave/automatic-pull-request-review@0.0.5 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + event: REQUEST_CHANGES + body: "Increase test coverage from ${{ env.COVERAGE_PERCENT }}% to at least 90% before merging." + + - name: Approve PR if Coverage Sufficient + if: ${{ env.COVERAGE_PERCENT > 89 }} + uses: andrewmusgrave/automatic-pull-request-review@0.0.5 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + event: APPROVE + body: "Test coverage meets or exceeds 90% threshold (currently ${{ env.COVERAGE_PERCENT }}%)." + + ci-report-status: + name: report CI status + needs: [build-and-test, ipynb-ci] + runs-on: ubuntu-latest + steps: + - run: | + result_1="${{ needs.build-and-test.result }}" + result_2="${{ needs.ipynb-ci.result }}" + if test $result_1 == "success" && test $result_2 == "success"; then + exit 0 + else + exit 1 + fi + \ No newline at end of file diff --git a/.github/workflows/coverage_reject.yml b/.github/workflows/coverage_reject.yml deleted file mode 100644 index 253ad49e..00000000 --- a/.github/workflows/coverage_reject.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: Ensure Sufficient Coverage -on: - pull_request: - branches: [main] - - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - build: - if: contains(github.event.pull_request.labels.*.name, 'PR Ready for Review') - runs-on: ubuntu-latest - defaults: - run: - shell: bash -el {0} - steps: - - uses: actions/checkout@v3 - - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: "3.10" - - name: Install Dependencies - run: | - python -m pip install -e .[molecules] - python -m pip install coverage - - name: Run Tests - run: | - coverage run --source=. --omit=astartes/__init__.py,setup.py,test/*,astartes/samplers/sampler.py -m unittest discover -v - - name: Show Coverage - run: | - coverage report -m > temp.txt - cat temp.txt - python .github/workflows/coverage_helper.py - echo "COVERAGE_PERCENT=$(cat temp2.txt)" >> $GITHUB_ENV - - - name: Request Changes via Review - if: ${{ env.COVERAGE_PERCENT < 90 }} - uses: andrewmusgrave/automatic-pull-request-review@0.0.5 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - event: REQUEST_CHANGES - body: "Increase test coverage from ${{ env.COVERAGE_PERCENT }}% to at least 90% before merging." - - - name: Approve PR if Coverage Sufficient - if: ${{ env.COVERAGE_PERCENT > 89 }} - uses: andrewmusgrave/automatic-pull-request-review@0.0.5 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - event: APPROVE - body: "Test coverage meets or exceeds 90% threshold (currently ${{ env.COVERAGE_PERCENT }}%)." diff --git a/.github/workflows/format_code.yml b/.github/workflows/format_code.yml deleted file mode 100644 index cea7931c..00000000 --- a/.github/workflows/format_code.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: Ensure Code Formatting -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - - workflow_dispatch: - -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Setup python - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: x64 - - name: Install Dependencies - run: | - python -m pip install pycodestyle autopep8 isort - python -m pip install -e . - - - name: Check Code Formatting Errors - run: | - pycodestyle --statistics --count --max-line-length=150 --show-source --ignore=E203 . - - - name: Check Import Ordering Errors - run: | - isort --check-only --verbose . \ No newline at end of file diff --git a/.github/workflows/ipynb_ci.yml b/.github/workflows/ipynb_ci.yml deleted file mode 100644 index df91d2ec..00000000 --- a/.github/workflows/ipynb_ci.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: Check Jupyer Notebook Execution -on: - push: - branches: [main] - pull_request: - branches: [main] - - workflow_dispatch: - -jobs: - build: - strategy: - fail-fast: false - matrix: - nb-file: - ["barrier_prediction_with_RDB7/RDB7_barrier_prediction_example", "train_val_test_split_sklearn_example/train_val_test_split_example", "split_comparisons/split_comparisons", "mlpds_2023_astartes_demonstration/mlpds_2023_demo"] - runs-on: ubuntu-latest - defaults: - run: - shell: bash -el {0} - name: Check ${{ matrix.nb-file }} Notebook Execution - steps: - - uses: actions/checkout@v3 - - uses: mamba-org/setup-micromamba@main - with: - environment-name: temp - condarc: | - channels: - - defaults - - conda-forge - channel_priority: flexible - create-args: | - python=3.11 - - name: Install dependencies - run: | - python -m pip install -e .[molecules,demos] - python -m pip install notebook - - name: Test Execution - run: | - cd examples/$(dirname ${{ matrix.nb-file }}) - jupyter nbconvert --to script $(basename ${{ matrix.nb-file }}).ipynb - ipython $(basename ${{ matrix.nb-file }}).py diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml deleted file mode 100644 index a5b0ce48..00000000 --- a/.github/workflows/run_tests.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Run Tests -on: - schedule: - - cron: "0 8 * * 1-5" - push: - branches: [main] - pull_request: - branches: [main] - - workflow_dispatch: - -concurrency: - group: actions-id-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - build: - strategy: - fail-fast: false - matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] - os: [ubuntu-latest, windows-latest, macos-latest] - - runs-on: ${{ matrix.os }} - defaults: - run: - shell: bash -el {0} - name: ${{ matrix.os }} Python ${{ matrix.python-version }} Subtest - steps: - - uses: actions/checkout@v3 - - uses: mamba-org/setup-micromamba@main - with: - environment-name: temp - condarc: | - channels: - - defaults - - conda-forge - channel_priority: flexible - create-args: | - python=${{ matrix.python-version }} - - name: Install Dependencies - run: | - python -m pip install -e .[molecules] - python -m pip install coverage - - name: Run Tests - run: | - coverage run --source=. --omit=astartes/__init__.py,setup.py,test/* -m unittest discover -v - - name: Show Coverage - run: | - coverage report -m diff --git a/README.md b/README.md index 07fc8064..35420e21 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Follow [this link](https://JacksonBurns.github.io/astartes/) for a nicely-render Keep reading for a installation guide and links to tutorials! ## Installing `astartes` -We recommend installing `astartes` within a virtual environment, using either `venv` or `conda` (or other tools) to simplify dependency management. Python versions 3.7, 3.8, 3.9, 3.10, 3.11, and 3.12 are supported on all platforms. +We recommend installing `astartes` within a virtual environment, using either `venv` or `conda` (or other tools) to simplify dependency management. Python versions 3.8, 3.9, 3.10, 3.11, and 3.12 are supported on all platforms. > **Warning** > Windows (PowerShell) and MacOS Catalina or newer (zsh) require double quotes around text using the `'[]'` characters (i.e. `pip install "astartes[molecules]"`). @@ -226,8 +226,10 @@ Do not provide a `random_state` in the `hopts` dictionary - it will be overwritt | Sample set Partitioning based on joint X-Y distances (SPXY) | 'spxy' | Interpolative | `distance_metric` | Saldhana et. al [original paper](https://www.sciencedirect.com/science/article/abs/pii/S003991400500192X) :small_blue_diamond: | Extension of Kennard Stone that also includes the response when sampling distances. | | Mahalanobis Distance Kennard Stone (MDKS) | 'spxy' _(MDKS is derived from SPXY)_ | Interpolative | _none, see Notes_ | Saptoro et. al [original paper](https://espace.curtin.edu.au/bitstream/handle/20.500.11937/45101/217844_70585_PUB-SE-DCE-FM-71008.pdf?sequence=2&isAllowed=y) | MDKS is SPXY using Mahalanobis distance and can be called by using SPXY with `distance_metric="mahalanobis"` | | Scaffold | 'scaffold' | Extrapolative | `include_chirality` | [Bemis-Murcko Scaffold](https://pubs.acs.org/doi/full/10.1021/jm9602928) :small_blue_diamond: as implemented in RDKit | This sampler requires SMILES strings as input (use the `molecules` subpackage) | +| Molecular Weight| 'molecular_weight' | Extrapolative | _none_ | ~ | Sorts molecules by molecular weight as calculated by RDKit | | Sphere Exclusion | 'sphere_exclusion' | Extrapolative | `metric`, `distance_cutoff` | _custom implementation_ | Variation on Sphere Exclusion for arbitrary-valued vectors. | | Time Based | 'time_based' | Extrapolative | _none_ | Papers using Time based splitting: [Chen et al.](https://pubs.acs.org/doi/full/10.1021/ci200615h) :small_blue_diamond:, [Sheridan, R. P](https://pubs.acs.org/doi/full/10.1021/ci400084k) :small_blue_diamond:, [Feinberg et al.](https://pubs.acs.org/doi/full/10.1021/acs.jmedchem.9b02187) :small_blue_diamond:, [Struble et al.](https://pubs.rsc.org/en/content/articlehtml/2020/re/d0re00071j) | This sampler requires `labels` to be an iterable of either date or datetime objects. | +| Target Property | 'target_property' | Extrapolative | `descending` | ~ | Sorts data by regression target y | | Optimizable K-Dissimilarity Selection (OptiSim) | 'optisim' | Extrapolative | `n_clusters`, `max_subsample_size`, `distance_cutoff` | _custom implementation_ | Variation on [OptiSim](https://pubs.acs.org/doi/10.1021/ci025662h) for arbitrary-valued vectors. | | K-Means | 'kmeans' | Extrapolative | `n_clusters`, `n_init` | [`sklearn KMeans`](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html) | Passthrough to `sklearn`'s `KMeans`. | | Density-Based Spatial Clustering of Applications with Noise (DBSCAN) | 'dbscan' | Extrapolative | `eps`, `min_samples`, `algorithm`, `metric`, `leaf_size` | [`sklearn DBSCAN`](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html) Documentation| Passthrough to `sklearn`'s `DBSCAN`. | diff --git a/astartes/__init__.py b/astartes/__init__.py index 14006d2d..807b453f 100644 --- a/astartes/__init__.py +++ b/astartes/__init__.py @@ -1,7 +1,7 @@ # convenience import to enable 'from astartes import train_test_split' from .main import train_test_split, train_val_test_split -__version__ = "1.1.5" +__version__ = "1.2.0" # DO NOT do this: # from .molecules import train_test_split_molecules diff --git a/astartes/main.py b/astartes/main.py index 8a9b197c..56121f4f 100644 --- a/astartes/main.py +++ b/astartes/main.py @@ -6,6 +6,7 @@ import pandas as pd from astartes.samplers import ( + DETERMINISTIC_EXTRAPOLATION_SAMPLERS, IMPLEMENTED_EXTRAPOLATION_SAMPLERS, IMPLEMENTED_INTERPOLATION_SAMPLERS, ) @@ -69,22 +70,19 @@ def train_val_test_split( if labels is not None and len(labels) != len(X): msg += "len(labels)={:d} ".format(len(labels)) if msg: - raise InvalidConfigurationError( - "Lengths of input arrays do not match: len(X)={:d} ".format(len(X)) + msg - ) + raise InvalidConfigurationError("Lengths of input arrays do not match: len(X)={:d} ".format(len(X)) + msg) - train_size, val_size, test_size = _normalize_split_sizes( - train_size, val_size, test_size - ) - hopts["random_state"] = ( - random_state if random_state is not None else DEFAULT_RANDOM_STATE - ) + train_size, val_size, test_size = _normalize_split_sizes(train_size, val_size, test_size) + hopts["random_state"] = random_state if random_state is not None else DEFAULT_RANDOM_STATE sampler_factory = SamplerFactory(sampler) sampler_instance = sampler_factory.get_sampler(X, y, labels, hopts) - if sampler in (*IMPLEMENTED_INTERPOLATION_SAMPLERS, "time_based"): - # time_based does extrapolation but does not support random_state - # because it always sorts in time order + if sampler in ( + *IMPLEMENTED_INTERPOLATION_SAMPLERS, + *DETERMINISTIC_EXTRAPOLATION_SAMPLERS, + ): + # extrapolation samplers in DETERMINISTIC_EXTRAPOLATION_SAMPLERS do not accept the + # random_state argument because they are entirely deterministic return _interpolative_sampling( sampler_instance, test_size, @@ -189,11 +187,7 @@ def _extrapolative_sampling( # largest clusters must go into largest set, but smaller ones can optionally # be shuffled - cluster_counter = sampler_instance.get_sorted_cluster_counter( - max_shufflable_size=max_shufflable_size - if random_state is not None - else None - ) + cluster_counter = sampler_instance.get_sorted_cluster_counter(max_shufflable_size=max_shufflable_size if random_state is not None else None) test_idxs, val_idxs, train_idxs = ( np.array([], dtype=int), @@ -203,20 +197,12 @@ def _extrapolative_sampling( for cluster_idx, cluster_length in cluster_counter.items(): # assemble test first, avoid overfilling if (len(test_idxs) + cluster_length) <= n_test_samples: - test_idxs = np.append( - test_idxs, sampler_instance.get_sample_idxs(cluster_length) - ) + test_idxs = np.append(test_idxs, sampler_instance.get_sample_idxs(cluster_length)) elif (len(val_idxs) + cluster_length) <= n_val_samples: - val_idxs = np.append( - val_idxs, sampler_instance.get_sample_idxs(cluster_length) - ) + val_idxs = np.append(val_idxs, sampler_instance.get_sample_idxs(cluster_length)) else: # then balance goes into train - train_idxs = np.append( - train_idxs, sampler_instance.get_sample_idxs(cluster_length) - ) - _check_actual_split( - train_idxs, val_idxs, test_idxs, train_size, val_size, test_size - ) + train_idxs = np.append(train_idxs, sampler_instance.get_sample_idxs(cluster_length)) + _check_actual_split(train_idxs, val_idxs, test_idxs, train_size, val_size, test_size) return return_helper( sampler_instance, train_idxs, @@ -261,9 +247,7 @@ def _interpolative_sampling( val_idxs = sampler_instance.get_sample_idxs(n_val_samples) test_idxs = sampler_instance.get_sample_idxs(n_test_samples) - _check_actual_split( - train_idxs, val_idxs, test_idxs, train_size, val_size, test_size - ) + _check_actual_split(train_idxs, val_idxs, test_idxs, train_size, val_size, test_size) return return_helper( sampler_instance, train_idxs, @@ -377,27 +361,17 @@ def _normalize_split_sizes(train_size, val_size, test_size): else: # one or the other - only allow floats [0, 1), then calculate if train_size: if train_size >= 1.0 or train_size <= 0: - raise InvalidConfigurationError( - "If specifying only train_size, must be float between (0, 1) (got {:.2f})".format( - train_size - ) - ) + raise InvalidConfigurationError("If specifying only train_size, must be float between (0, 1) (got {:.2f})".format(train_size)) test_size = 1.0 - train_size out = [train_size, 0, test_size] else: if test_size >= 1.0 or test_size <= 0: - raise InvalidConfigurationError( - "If specifying only test_size, must be float between (0, 1) (got {:.2f})".format( - test_size - ) - ) + raise InvalidConfigurationError("If specifying only test_size, must be float between (0, 1) (got {:.2f})".format(test_size)) train_size = 1.0 - test_size out = [train_size, 0, test_size] else: # there is a non-zero val_size if val_size >= 1.0 or val_size <= 0.0: - raise InvalidConfigurationError( - "val_size must be a float between (0, 1) (got {:.2f})".format(val_size) - ) + raise InvalidConfigurationError("val_size must be a float between (0, 1) (got {:.2f})".format(val_size)) if train_size and test_size: # all three - normalize if train_size + test_size + val_size != 1.0: normalization = train_size + test_size + val_size @@ -423,18 +397,14 @@ def _normalize_split_sizes(train_size, val_size, test_size): if train_size: if train_size >= 1.0 or train_size <= 0: raise InvalidConfigurationError( - "If specifying val_size and only train_size, must be float between (0, 1) (got {:.2f})".format( - train_size - ) + "If specifying val_size and only train_size, must be float between (0, 1) (got {:.2f})".format(train_size) ) test_size = 1.0 - (train_size + val_size) out = [train_size, val_size, test_size] else: if test_size >= 1.0 or test_size <= 0: raise InvalidConfigurationError( - "If specifying val_size and only test_size, must be float between (0, 1) (got {:.2f})".format( - test_size - ) + "If specifying val_size and only test_size, must be float between (0, 1) (got {:.2f})".format(test_size) ) train_size = 1.0 - (test_size + val_size) out = [train_size, val_size, test_size] diff --git a/astartes/molecules.py b/astartes/molecules.py index 1e694ba7..2fbb08b5 100644 --- a/astartes/molecules.py +++ b/astartes/molecules.py @@ -1,30 +1,11 @@ -import warnings - import numpy as np -from astartes.utils.exceptions import MoleculesNotInstalledError - -try: - """ - aimsim depends on sklearn_extra, which uses a version checking technique that is due to - be deprecated in a version of Python after 3.11, so it is throwing a deprecation warning - We ignore this warning since we can't do anything about it (sklearn_extra seems to be - abandonware) and in the future it will become an error that we can deal with. - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=DeprecationWarning) - from aimsim.chemical_datastructures import Molecule - from aimsim.exceptions import LoadingError -except ImportError: # pragma: no cover - raise MoleculesNotInstalledError( - """To use molecule featurizer, install astartes with pip install astartes[molecules].""" - ) - # at this point we have successfully verified that rdkit is installed, so we can do this: from rdkit.rdBase import SeedRandomNumberGenerator from astartes import train_test_split, train_val_test_split from astartes.main import DEFAULT_RANDOM_STATE +from astartes.utils.aimsim_featurizer import featurize_molecules SeedRandomNumberGenerator(DEFAULT_RANDOM_STATE) @@ -65,7 +46,7 @@ def train_val_test_split_molecules( if sampler == "scaffold": X = molecules else: - X = _featurize(molecules, fingerprint, fprints_hopts) + X = featurize_molecules(molecules, fingerprint, fprints_hopts) return train_val_test_split( X, y=y, @@ -115,7 +96,7 @@ def train_test_split_molecules( if sampler == "scaffold": X = molecules else: - X = _featurize(molecules, fingerprint, fprints_hopts) + X = featurize_molecules(molecules, fingerprint, fprints_hopts) # call train test split with this input return train_test_split( @@ -129,37 +110,3 @@ def train_test_split_molecules( hopts=hopts, return_indices=return_indices, ) - - -def _featurize(molecules, fingerprint, fprints_hopts): - """Call AIMSim's Molecule to featurize the molecules according to the arguments. - - Args: - molecules (np.array): SMILES strings or RDKit molecule objects. - fingerprint (str): The molecular fingerprint to be used. - fprints_hopts (dict): Hyperparameters for AIMSim. - - Returns: - np.array: X array (featurized molecules) - """ - X = [] - for molecule in molecules: - try: - if type(molecule) in (np.str_, str): - mol = Molecule(mol_smiles=molecule) - else: - mol = Molecule(mol_graph=molecule) - except LoadingError as le: - raise RuntimeError( - "Unable to featurize molecules using '{:s}' with this configuration: fprint_hopts={:s}" - "\nCheck terminal output for messages from the RDkit logger. ".format( - fingerprint, repr(fprints_hopts) - ) - ) from le - mol.descriptor.make_fingerprint( - mol.mol_graph, - fingerprint, - fingerprint_params=fprints_hopts, - ) - X.append(mol.descriptor.to_numpy()) - return np.array(X) diff --git a/astartes/samplers/__init__.py b/astartes/samplers/__init__.py index eaff11cc..8beaa651 100644 --- a/astartes/samplers/__init__.py +++ b/astartes/samplers/__init__.py @@ -2,7 +2,16 @@ from .abstract_sampler import AbstractSampler # implementations -from .extrapolation import DBSCAN, KMeans, OptiSim, Scaffold, SphereExclusion, TimeBased +from .extrapolation import ( + DBSCAN, + KMeans, + MolecularWeight, + OptiSim, + Scaffold, + SphereExclusion, + TargetProperty, + TimeBased, +) from .interpolation import SPXY, KennardStone, Random IMPLEMENTED_INTERPOLATION_SAMPLERS = ( @@ -15,9 +24,17 @@ "dbscan", "scaffold", "kmeans", + "molecular_weight", "optisim", "sphere_exclusion", "time_based", + "target_property", ) ALL_SAMPLERS = IMPLEMENTED_EXTRAPOLATION_SAMPLERS + IMPLEMENTED_INTERPOLATION_SAMPLERS + +DETERMINISTIC_EXTRAPOLATION_SAMPLERS = ( + "time_based", + "target_property", + "molecular_weight", +) diff --git a/astartes/samplers/abstract_sampler.py b/astartes/samplers/abstract_sampler.py index 75702383..db4cf2cf 100644 --- a/astartes/samplers/abstract_sampler.py +++ b/astartes/samplers/abstract_sampler.py @@ -1,4 +1,5 @@ """Abstract Sampling class""" + from abc import ABC, abstractmethod from collections import Counter @@ -30,6 +31,7 @@ def __init__(self, X, y, labels, configs): self._current_sample_idx = 0 self._before_sample() self._sample() + self._after_sample() def _before_sample(self): """This method should perform any data validation, manipulation, etc. required before proceeding to _sample @@ -38,6 +40,13 @@ def _before_sample(self): None: Returns nothing, raises an Exception if something is wrong. """ + def _after_sample(self): + """This method should perform any checks, mutations, etc. required after _sample is completed. + + Returns: + None: Returns nothing, raises an Exception if something is wrong. + """ + @abstractmethod def _sample(self): """ diff --git a/astartes/samplers/extrapolation/__init__.py b/astartes/samplers/extrapolation/__init__.py index 36dbb823..fe093727 100644 --- a/astartes/samplers/extrapolation/__init__.py +++ b/astartes/samplers/extrapolation/__init__.py @@ -1,6 +1,8 @@ from .dbscan import DBSCAN from .kmeans import KMeans +from .molecular_weight import MolecularWeight from .optisim import OptiSim from .scaffold import Scaffold from .sphere_exclusion import SphereExclusion +from .target_property import TargetProperty from .time_based import TimeBased diff --git a/astartes/samplers/extrapolation/molecular_weight.py b/astartes/samplers/extrapolation/molecular_weight.py new file mode 100644 index 00000000..0d62cd2c --- /dev/null +++ b/astartes/samplers/extrapolation/molecular_weight.py @@ -0,0 +1,32 @@ +""" +This sampler partitions the data based on molecular weight. It first sorts the +molecules by molecular weight and then places the smallest molecules in the training set, +the next smallest in the validation set if applicable, and finally the largest molecules +in the testing set. +""" + +import numpy as np + +try: + from astartes.utils.aimsim_featurizer import featurize_molecules +except ImportError: + # this is in place so that the import of this from parent directory will work + # if it fails, it is caught in molecules instead and the error is more helpful + NO_MOLECULES = True + +from .scaffold import Scaffold +from .target_property import TargetProperty + + +# inherit sample method from TargetProperty +class MolecularWeight(TargetProperty): + def _before_sample(self): + # check for invalid data types using the method in the Scaffold sampler + Scaffold._validate_input(self.X) + # calculate the average molecular weight of the molecule + self.y_backup = self.y + self.y = featurize_molecules((Scaffold.str_to_mol(i) for i in self.X), "mordred:MW", fprints_hopts={}) + + def _after_sample(self): + # restore the original y values + self.y = self.y_backup diff --git a/astartes/samplers/extrapolation/scaffold.py b/astartes/samplers/extrapolation/scaffold.py index 74b7e601..b3856902 100644 --- a/astartes/samplers/extrapolation/scaffold.py +++ b/astartes/samplers/extrapolation/scaffold.py @@ -9,6 +9,7 @@ that are not in the training set. """ + import warnings from collections import defaultdict @@ -27,12 +28,15 @@ class Scaffold(AbstractSampler): - def _before_sample(self): + def _validate_input(X): # ensure that X contains entries that are either a SMILES string or an RDKit Molecule - if not all(isinstance(i, str) for i in self.X) and not all(isinstance(i, Chem.rdchem.Mol) for i in self.X): + if not all(isinstance(i, str) for i in X) and not all(isinstance(i, Chem.rdchem.Mol) for i in X): msg = "Scaffold class requires input X to be an iterable of SMILES strings, InChI strings, or RDKit Molecules" raise TypeError(msg) + def _before_sample(self): + Scaffold._validate_input(self.X) + def _sample(self): """Implements the Scaffold sampler to identify clusters via a molecule's Bemis-Murcko scaffold.""" scaffold_to_indices = self.scaffold_to_smiles(self.X) @@ -42,8 +46,7 @@ def _sample(self): for cluster_id, (scaffold, indices) in enumerate(scaffold_to_indices.items()): if scaffold == "": warnings.warn( - f"No matching scaffold was found for the {len(indices)} " - f"molecules corresponding to indices {indices}", + f"No matching scaffold was found for the {len(indices)} " f"molecules corresponding to indices {indices}", NoMatchingScaffold, ) for idx in indices: @@ -63,12 +66,12 @@ def scaffold_to_smiles(self, mols): """ scaffolds = defaultdict(set) for i, mol in enumerate(mols): - scaffold = self.generate_bemis_murcko_scaffold(mol, self.get_config("include_chirality", False)) + scaffold = Scaffold.generate_bemis_murcko_scaffold(mol, self.get_config("include_chirality", False)) scaffolds[scaffold].add(i) return scaffolds - def str_to_mol(self, string): + def str_to_mol(string): """ Converts an InChI or SMILES string to an RDKit molecule. @@ -92,7 +95,7 @@ def str_to_mol(self, string): return mol - def generate_bemis_murcko_scaffold(self, mol, include_chirality=False): + def generate_bemis_murcko_scaffold(mol, include_chirality=False): """ Compute the Bemis-Murcko scaffold for an RDKit molecule. @@ -103,7 +106,7 @@ def generate_bemis_murcko_scaffold(self, mol, include_chirality=False): Returns: Bemis-Murcko scaffold """ - mol = self.str_to_mol(mol) if isinstance(mol, str) else mol + mol = Scaffold.str_to_mol(mol) if isinstance(mol, str) else mol scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality) return scaffold diff --git a/astartes/samplers/extrapolation/target_property.py b/astartes/samplers/extrapolation/target_property.py new file mode 100644 index 00000000..0fe63c94 --- /dev/null +++ b/astartes/samplers/extrapolation/target_property.py @@ -0,0 +1,23 @@ +""" +This sampler partitions the data based on the regression target y. It first sorts the +data by y value and then constructs the training set to have either the smallest (largest) +y values, the validation set to have the next smallest (largest) set of y values, and the +testing set to have the largest (smallest) y values. +""" + +import numpy as np + +from astartes.samplers import AbstractSampler + + +class TargetProperty(AbstractSampler): + def _sample(self): + """ + Implements the target property sampler to create an extrapolation split. + """ + data = [(y, idx) for y, idx in zip(self.y, np.arange(len(self.y)))] + + # by default, the smallest property values are placed in the training set + sorted_list = sorted(data, reverse=self.get_config("descending", False)) + + self._samples_idxs = np.array([idx for time, idx in sorted_list], dtype=int) diff --git a/astartes/utils/aimsim_featurizer.py b/astartes/utils/aimsim_featurizer.py new file mode 100644 index 00000000..ff85e3a8 --- /dev/null +++ b/astartes/utils/aimsim_featurizer.py @@ -0,0 +1,51 @@ +import warnings + +import numpy as np + +from astartes.utils.exceptions import MoleculesNotInstalledError + +try: + """ + aimsim depends on sklearn_extra, which uses a version checking technique that is due to + be deprecated in a version of Python after 3.11, so it is throwing a deprecation warning + We ignore this warning since we can't do anything about it (sklearn_extra seems to be + abandonware) and in the future it will become an error that we can deal with. + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + from aimsim.chemical_datastructures import Molecule + from aimsim.exceptions import LoadingError +except ImportError: # pragma: no cover + raise MoleculesNotInstalledError("""To use molecule featurizer, install astartes with pip install astartes[molecules].""") + + +def featurize_molecules(molecules, fingerprint, fprints_hopts): + """Call AIMSim's Molecule to featurize the molecules according to the arguments. + + Args: + molecules (np.array): SMILES strings or RDKit molecule objects. + fingerprint (str): The molecular fingerprint to be used. + fprints_hopts (dict): Hyperparameters for AIMSim. + + Returns: + np.array: X array (featurized molecules) + """ + X = [] + for molecule in molecules: + try: + if type(molecule) in (np.str_, str): + mol = Molecule(mol_smiles=molecule) + else: + mol = Molecule(mol_graph=molecule) + except LoadingError as le: + raise RuntimeError( + "Unable to featurize molecules using '{:s}' with this configuration: fprint_hopts={:s}" + "\nCheck terminal output for messages from the RDkit logger. ".format(fingerprint, repr(fprints_hopts)) + ) from le + mol.descriptor.make_fingerprint( + mol.mol_graph, + fingerprint, + fingerprint_params=fprints_hopts, + ) + X.append(mol.descriptor.to_numpy()) + return np.array(X) diff --git a/astartes/utils/sampler_factory.py b/astartes/utils/sampler_factory.py index 305f13ef..28d869e0 100644 --- a/astartes/utils/sampler_factory.py +++ b/astartes/utils/sampler_factory.py @@ -6,10 +6,12 @@ SPXY, KennardStone, KMeans, + MolecularWeight, OptiSim, Random, Scaffold, SphereExclusion, + TargetProperty, TimeBased, ) from astartes.utils.exceptions import SamplerNotImplementedError @@ -49,14 +51,18 @@ def get_sampler(self, X, y, labels, hopts): sampler_class = KMeans elif self.sampler == "sphere_exclusion": sampler_class = SphereExclusion + elif self.sampler == "molecular_weight": + sampler_class = MolecularWeight elif self.sampler == "optisim": sampler_class = OptiSim elif self.sampler == "spxy": sampler_class = SPXY elif self.sampler == "scaffold": sampler_class = Scaffold - elif self.sampler == 'time_based': + elif self.sampler == "time_based": sampler_class = TimeBased + elif self.sampler == "target_property": + sampler_class = TargetProperty else: possiblity = get_close_matches( self.sampler, diff --git a/pyproject.toml b/pyproject.toml index 1b4c993b..b7a9c402 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ authors = [ license = { text = "MIT" } description = "Train:Test Algorithmic Sampling for Molecules and Arbitrary Arrays" classifiers = [ - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -21,7 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] urls = { Homepage = "https://github.com/JacksonBurns/astartes" } -requires-python = ">=3.7" +requires-python = ">=3.8" dependencies = ["scikit_learn", "tabulate", "numpy", "scipy", "pandas"] [project.optional-dependencies] diff --git a/test/functional/test_molecules.py b/test/functional/test_molecules.py index 52556d6b..d1c77258 100644 --- a/test/functional/test_molecules.py +++ b/test/functional/test_molecules.py @@ -11,6 +11,7 @@ train_val_test_split_molecules, ) from astartes.samplers import ( + DETERMINISTIC_EXTRAPOLATION_SAMPLERS, IMPLEMENTED_EXTRAPOLATION_SAMPLERS, IMPLEMENTED_INTERPOLATION_SAMPLERS, ) @@ -82,7 +83,7 @@ def test_validation_split_molecules(self): with warnings.catch_warnings(): warnings.simplefilter("ignore") for sampler in IMPLEMENTED_EXTRAPOLATION_SAMPLERS: - if sampler in ("scaffold", "time_based"): + if sampler in ("scaffold", *DETERMINISTIC_EXTRAPOLATION_SAMPLERS): continue tts = train_val_test_split_molecules( self.X, @@ -142,9 +143,7 @@ def test_sampler_hopts(self): "\nNo warnings should have been raised when requesting a mathematically possible split." "\nReceived {:d} warnings instead: \n -> {:s}".format( len(w), - "\n -> ".join( - [str(i.category.__name__) + ": " + str(i.message) for i in w] - ), + "\n -> ".join([str(i.category.__name__) + ": " + str(i.message) for i in w]), ), ) diff --git a/test/regression/test_regression.py b/test/regression/test_regression.py index b3f81866..7e304aad 100644 --- a/test/regression/test_regression.py +++ b/test/regression/test_regression.py @@ -9,13 +9,12 @@ from astartes import train_val_test_split from astartes.samplers import ( ALL_SAMPLERS, + DETERMINISTIC_EXTRAPOLATION_SAMPLERS, IMPLEMENTED_EXTRAPOLATION_SAMPLERS, IMPLEMENTED_INTERPOLATION_SAMPLERS, ) -SKLEARN_GEQ_13 = ( # get the sklearn version - int(pkg_resources.get_distribution("scikit-learn").version.split(".")[1]) >= 3 -) +SKLEARN_GEQ_13 = int(pkg_resources.get_distribution("scikit-learn").version.split(".")[1]) >= 3 # get the sklearn version class Test_regression(unittest.TestCase): @@ -29,13 +28,9 @@ def setUpClass(self): rng = np.random.default_rng(42) self.X = rng.random((100, 100)) self.y = rng.random((100,)) - self.labels_datetime = np.array( - [datetime.strptime(f"20{y:02}/01/01", "%Y/%m/%d") for y in range(100)] - ) + self.labels_datetime = np.array([datetime.strptime(f"20{y:02}/01/01", "%Y/%m/%d") for y in range(100)]) cwd = os.getcwd() - self.reference_splits_dir = os.path.join( - cwd, "test", "regression", "reference_splits" - ) + self.reference_splits_dir = os.path.join(cwd, "test", "regression", "reference_splits") self.reference_splits = { name: os.path.join(self.reference_splits_dir, name + "_reference.pkl") for name in ALL_SAMPLERS @@ -87,9 +82,7 @@ def test_timebased_regression(self): with open(self.reference_splits["time_based"], "rb") as f: reference_output = pkl.load(f) for i, j in zip(all_output, reference_output): - np.testing.assert_array_equal( - i, j, "Sampler time_based failed regression testing." - ) + np.testing.assert_array_equal(i, j, "Sampler time_based failed regression testing.") def test_interpolation_regression(self): """Regression testing of interpolative methods relative to static results.""" @@ -104,14 +97,12 @@ def test_interpolation_regression(self): with open(self.reference_splits[sampler_name], "rb") as f: reference_output = pkl.load(f) for i, j in zip(all_output, reference_output): - np.testing.assert_array_equal( - i, j, "Sampler {:s} failed regression testing.".format(sampler_name) - ) + np.testing.assert_array_equal(i, j, "Sampler {:s} failed regression testing.".format(sampler_name)) def test_extrapolation_regression(self): """Regression testing of extrapolative methods relative to static results.""" for sampler_name in IMPLEMENTED_EXTRAPOLATION_SAMPLERS: - if sampler_name in ("scaffold", "time_based", "kmeans"): + if sampler_name in ("scaffold", "kmeans", *DETERMINISTIC_EXTRAPOLATION_SAMPLERS): continue ( X_train, diff --git a/test/unit/samplers/extrapolative/test_molecular_weight.py b/test/unit/samplers/extrapolative/test_molecular_weight.py new file mode 100644 index 00000000..39e0ba3b --- /dev/null +++ b/test/unit/samplers/extrapolative/test_molecular_weight.py @@ -0,0 +1,165 @@ +import unittest + +import numpy as np + +from astartes import train_test_split +from astartes.samplers import MolecularWeight + + +class Test_MolecularWeight(unittest.TestCase): + """ + Test the various functionalities of MolecularWeight. + """ + + @classmethod + def setUpClass(self): + """Convenience attributes for later tests.""" + self.X = np.array( + [ + "C", + "CC", + "CCC", + "CCCC", + "CCCCC", + "CCCCCC", + "CCCCCCC", + "CCCCCCCC", + "CCCCCCCCC", + "CCCCCCCCCC", + ] + ) + self.X_inchi = np.array( + [ + "InChI=1S/CH4/h1H4", + "InChI=1S/C2H6/c1-2/h1-2H3", + "InChI=1S/C3H8/c1-3-2/h3H2,1-2H3", + "InChI=1S/C4H10/c1-3-4-2/h3-4H2,1-2H3", + "InChI=1S/C5H12/c1-3-5-4-2/h3-5H2,1-2H3", + "InChI=1S/C6H14/c1-3-5-6-4-2/h3-6H2,1-2H3", + "InChI=1S/C7H16/c1-3-5-7-6-4-2/h3-7H2,1-2H3", + "InChI=1S/C8H18/c1-3-5-7-8-6-4-2/h3-8H2,1-2H3", + "InChI=1S/C9H20/c1-3-5-7-9-8-6-4-2/h3-9H2,1-2H3", + "InChI=1S/C10H22/c1-3-5-7-9-10-8-6-4-2/h3-10H2,1-2H3", + ] + ) + self.y = np.arange(len(self.X)) + self.labels = np.array( + [ + "methane", + "ethane", + "propane", + "butane", + "pentane", + "hexane", + "heptane", + "octane", + "nonane", + "decane", + ] + ) + + def test_molecular_weight_sampling(self): + """Use MolecularWeight in the train_test_split and verify results.""" + ( + X_train, + X_test, + y_train, + y_test, + labels_train, + labels_test, + ) = train_test_split( + self.X, + self.y, + labels=self.labels, + test_size=0.2, + train_size=0.8, + sampler="molecular_weight", + hopts={}, + ) + + # test that the known arrays equal the result from above + self.assertIsNone( + np.testing.assert_array_equal( + X_train, + self.X[:8], # X was already sorted by ascending molecular weight + ), + "Train X incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + X_test, + self.X[8:], # X was already sorted by ascending molecular weight + ), + "Test X incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + y_train, + self.y[:8], # y was already sorted by ascending molecular weight + ), + "Train y incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + y_test, + self.y[8:], # y was already sorted by ascending molecular weight + ), + "Test y incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + labels_train, + self.labels[:8], # labels was already sorted by ascending molecular weight + ), + "Train labels incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + labels_test, + self.labels[8:], # labels was already sorted by ascending molecular weight + ), + "Test labels incorrect.", + ) + + def test_molecular_weight(self): + """Directly instantiate and test MolecularWeight.""" + molecular_weight_instance = MolecularWeight( + self.X, + self.y, + self.labels, + {}, + ) + self.assertIsInstance( + molecular_weight_instance, + MolecularWeight, + "Failed instantiation.", + ) + self.assertFalse( + len(molecular_weight_instance.get_clusters()), + "Clusters was set when it should not have been.", + ) + self.assertTrue( + len(molecular_weight_instance._samples_idxs), + "Sample indices not set.", + ) + + def test_incorrect_input(self): + """Calling with something other than SMILES, InChI, or RDKit Molecule should raise TypeError""" + with self.assertRaises(TypeError): + train_test_split( + np.array([[1], [2]]), + sampler="molecular_weight", + ) + + def test_mol_from_inchi(self): + """Ability to load data from InChi inputs""" + MolecularWeight( + self.X_inchi, + None, + None, + {}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit/samplers/extrapolative/test_target_property.py b/test/unit/samplers/extrapolative/test_target_property.py new file mode 100644 index 00000000..44a5b5a3 --- /dev/null +++ b/test/unit/samplers/extrapolative/test_target_property.py @@ -0,0 +1,194 @@ +import unittest + +import numpy as np + +from astartes import train_test_split +from astartes.samplers import TargetProperty + + +class Test_TargetProperty(unittest.TestCase): + """ + Test the various functionalities of TargetProperty. + """ + + @classmethod + def setUpClass(self): + """Convenience attributes for later tests.""" + self.X = np.array( + [ + "C", + "CC", + "CCC", + "CCCC", + "CCCCC", + "CCCCCC", + "CCCCCCC", + "CCCCCCCC", + "CCCCCCCCC", + "CCCCCCCCCC", + ] + ) + + self.y = np.arange(len(self.X)) + self.labels = np.array( + [ + "methane", + "ethane", + "propane", + "butane", + "pentane", + "hexane", + "heptane", + "octane", + "nonane", + "decane", + ] + ) + + def test_target_property_sampling_ascending(self): + """Use TargetProperty in the train_test_split and verify results.""" + ( + X_train, + X_test, + y_train, + y_test, + labels_train, + labels_test, + ) = train_test_split( + self.X, + self.y, + labels=self.labels, + test_size=0.2, + train_size=0.8, + sampler="target_property", + hopts={"descending": False}, + ) + + # test that the known arrays equal the result from above + self.assertIsNone( + np.testing.assert_array_equal( + X_train, + self.X[:8], # X was already sorted by ascending target value + ), + "Train X incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + X_test, + self.X[8:], # X was already sorted by ascending target value + ), + "Test X incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + y_train, + self.y[:8], # y was already sorted by ascending target value + ), + "Train y incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + y_test, + self.y[8:], # y was already sorted by ascending target value + ), + "Test y incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + labels_train, + self.labels[:8], # labels was already sorted by ascending target value + ), + "Train labels incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + labels_test, + self.labels[8:], # labels was already sorted by ascending target value + ), + "Test labels incorrect.", + ) + + def test_target_property_sampling_descending(self): + """Use TargetProperty in the train_test_split and verify results.""" + ( + X_train, + X_test, + y_train, + y_test, + labels_train, + labels_test, + ) = train_test_split( + self.X, + self.y, + labels=self.labels, + test_size=0.2, + train_size=0.8, + sampler="target_property", + hopts={"descending": True}, + ) + + # test that the known arrays equal the result from above + self.assertIsNone( + np.testing.assert_array_equal( + X_train, + np.flip(self.X)[:8], + ), + "Train X incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + X_test, + np.flip(self.X)[8:], + ), + "Test X incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + y_train, + np.flip(self.y)[:8], + ), + "Train y incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + y_test, + np.flip(self.y)[8:], + ), + "Test y incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + labels_train, + np.flip(self.labels)[:8], + ), + "Train labels incorrect.", + ) + self.assertIsNone( + np.testing.assert_array_equal( + labels_test, + np.flip(self.labels)[8:], + ), + "Test labels incorrect.", + ) + + def test_target_property(self): + """Directly instantiate and test TargetProperty.""" + target_property_instance = TargetProperty( + self.X, + self.y, + self.labels, + {}, + ) + self.assertIsInstance( + target_property_instance, + TargetProperty, + "Failed instantiation.", + ) + self.assertFalse( + len(target_property_instance.get_clusters()), + "Clusters was set when it should not have been.", + ) + self.assertTrue( + len(target_property_instance._samples_idxs), + "Sample indices not set.", + ) diff --git a/test/unit/utils/test_sampler_factory.py b/test/unit/utils/test_sampler_factory.py index 9594cfc0..4c34bd12 100644 --- a/test/unit/utils/test_sampler_factory.py +++ b/test/unit/utils/test_sampler_factory.py @@ -4,7 +4,11 @@ import numpy as np -from astartes.samplers import ALL_SAMPLERS, AbstractSampler +from astartes.samplers import ( + ALL_SAMPLERS, + DETERMINISTIC_EXTRAPOLATION_SAMPLERS, + AbstractSampler, +) from astartes.utils.sampler_factory import SamplerFactory @@ -28,7 +32,7 @@ def setUpClass(self): def test_train_test_split(self): """Call sampler factory on all inputs.""" for sampler_name in ALL_SAMPLERS: - if sampler_name in ("scaffold", "time_based"): + if sampler_name in ("scaffold", *DETERMINISTIC_EXTRAPOLATION_SAMPLERS): continue test_factory = SamplerFactory(sampler_name) test_instance = test_factory.get_sampler(self.X, self.y, None, {})