Skip to content

Commit

Permalink
move the MLM sampler into a demo notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonBurns committed Aug 13, 2024
1 parent 28e8d9c commit b794dff
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 46 deletions.
46 changes: 0 additions & 46 deletions astartes/samplers/interpolation/mlm.py

This file was deleted.

90 changes: 90 additions & 0 deletions examples/morais_lima_martin_sampling/mlm_sampler.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Implementing the Morais-Lima-Martin (MLM) Sampler\n",
"The notebook shows a brief demonstration of using the built in utilities in `astartes` to implement the Morais-Lima-Martin sampler, which you can read about [here](https://academic.oup.com/bioinformatics/article/35/24/5257/5497250)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`astartes` has a very fast implementation of the Kennard-Stone algorithm, on which the MLM sampler is based, available in its `utils`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from astartes.utils.fast_kennard_stone import fast_kennard_stone"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MLM sampler can then be implemented as shown below.\n",
"The `mlm_sampler` functions takes a 2D array and splits it first using the Kennard-Stone algorithm, then permutes the indices according to the MLM algorithm."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.spatial.distance import pdist, squareform\n",
"import numpy as np\n",
"\n",
"from astartes.samplers.interpolation import KennardStone\n",
"\n",
"\n",
"def mlm_split(X: np.ndarray, *, train_size: float = 0.8, val_size: float = 0.1, test_size: float = 0.1, random_state: int = 42):\n",
" # calculate the distance matrix\n",
" ks_indexes = fast_kennard_stone(squareform(pdist(X, \"euclidean\")))\n",
" pivot = int(len(ks_indexes) * train_size)\n",
" train_idxs = ks_indexes[0:pivot]\n",
" other_idxs = ks_indexes[pivot:]\n",
"\n",
" # set RNG\n",
" rng = np.random.default_rng(seed=random_state)\n",
" \n",
" # choose 10% of train to switch with 10% of val/test\n",
" n_to_permute = np.floor(0.1 * len(train_idxs))\n",
" train_permute_idxs = rng.choice(train_idxs, n_to_permute)\n",
" remaining_train_idxs = filter(lambda i: i not in train_permute_idxs, train_idxs)\n",
" other_permute_idxs = rng.choice(other_idxs, n_to_permute)\n",
" remaining_other_idxs = filter(lambda i: i not in other_permute_idxs, other_idxs)\n",
"\n",
" # reassemble the new lists of indexes\n",
" new_train_idxs = np.concatenate(remaining_train_idxs, other_permute_idxs)\n",
" new_other_idxs = np.concatenate(train_permute_idxs, remaining_other_idxs)\n",
" n_val = int(len(new_other_idxs) * (val_size / (val_size + test_size)))\n",
" val_indexes = new_other_idxs[0:n_val]\n",
" test_indexes = new_other_idxs[n_val:]\n",
" \n",
" # return the split up array\n",
" return X[train_idxs], X[val_indexes], X[test_indexes]\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "fprop",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit b794dff

Please sign in to comment.