diff --git a/astartes/samplers/interpolation/mlm.py b/astartes/samplers/interpolation/mlm.py deleted file mode 100644 index 273d47e..0000000 --- a/astartes/samplers/interpolation/mlm.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import overload - -import numpy as np - -from astartes.samplers.interpolation import KennardStone - - -class MLM(KennardStone): - # could be convenient to know size of train and test during init... - @overload - def get_sample_idxs(self, n_samples): - """Overload the KennardStone method to permute 10% of samples from train - - Args: - n_samples (int): Number of samples to retrieve. - - Returns: - np.array: The selected indices - """ - if self._current_sample_idx == 0: # permute indexes on the first call - train_idxs = self._samples_idxs[ - self._current_sample_idx : self._current_sample_idx + n_samples - ] - other_idxs = self._samples_idxs[self._current_sample_idx + n_samples : -1] - - # set RNG - rng = np.random.default_rng(seed=self.get_config("random_state")) - n_to_permute = np.floor(0.1 * len(train_idxs)) - train_permute_idxs = rng.choice(train_idxs, n_to_permute) - remaining_train_idxs = [ - i for i in train_idxs if i not in train_permute_idxs - ] - other_permute_idxs = rng.choice(other_idxs, n_to_permute) - remaining_other_idxs = [ - i for i in other_idxs if i not in other_permute_idxs - ] - # reassamble the indexes - self._samples_idxs = np.hstack( - ( - remaining_train_idxs, - other_permute_idxs, - remaining_other_idxs, - train_permute_idxs, - ) - ) - return super().get_sample_idxs(n_samples) diff --git a/examples/morais_lima_martin_sampling/mlm_sampler.ipynb b/examples/morais_lima_martin_sampling/mlm_sampler.ipynb new file mode 100644 index 0000000..c8f0f87 --- /dev/null +++ b/examples/morais_lima_martin_sampling/mlm_sampler.ipynb @@ -0,0 +1,90 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Implementing the Morais-Lima-Martin (MLM) Sampler\n", + "The notebook shows a brief demonstration of using the built in utilities in `astartes` to implement the Morais-Lima-Martin sampler, which you can read about [here](https://academic.oup.com/bioinformatics/article/35/24/5257/5497250)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`astartes` has a very fast implementation of the Kennard-Stone algorithm, on which the MLM sampler is based, available in its `utils`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from astartes.utils.fast_kennard_stone import fast_kennard_stone" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The MLM sampler can then be implemented as shown below.\n", + "The `mlm_sampler` functions takes a 2D array and splits it first using the Kennard-Stone algorithm, then permutes the indices according to the MLM algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.spatial.distance import pdist, squareform\n", + "import numpy as np\n", + "\n", + "from astartes.samplers.interpolation import KennardStone\n", + "\n", + "\n", + "def mlm_split(X: np.ndarray, *, train_size: float = 0.8, val_size: float = 0.1, test_size: float = 0.1, random_state: int = 42):\n", + " # calculate the distance matrix\n", + " ks_indexes = fast_kennard_stone(squareform(pdist(X, \"euclidean\")))\n", + " pivot = int(len(ks_indexes) * train_size)\n", + " train_idxs = ks_indexes[0:pivot]\n", + " other_idxs = ks_indexes[pivot:]\n", + "\n", + " # set RNG\n", + " rng = np.random.default_rng(seed=random_state)\n", + " \n", + " # choose 10% of train to switch with 10% of val/test\n", + " n_to_permute = np.floor(0.1 * len(train_idxs))\n", + " train_permute_idxs = rng.choice(train_idxs, n_to_permute)\n", + " remaining_train_idxs = filter(lambda i: i not in train_permute_idxs, train_idxs)\n", + " other_permute_idxs = rng.choice(other_idxs, n_to_permute)\n", + " remaining_other_idxs = filter(lambda i: i not in other_permute_idxs, other_idxs)\n", + "\n", + " # reassemble the new lists of indexes\n", + " new_train_idxs = np.concatenate(remaining_train_idxs, other_permute_idxs)\n", + " new_other_idxs = np.concatenate(train_permute_idxs, remaining_other_idxs)\n", + " n_val = int(len(new_other_idxs) * (val_size / (val_size + test_size)))\n", + " val_indexes = new_other_idxs[0:n_val]\n", + " test_indexes = new_other_idxs[n_val:]\n", + " \n", + " # return the split up array\n", + " return X[train_idxs], X[val_indexes], X[test_indexes]\n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fprop", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}