generated from JacksonBurns/blank-python-project
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move the MLM sampler into a demo notebook
- Loading branch information
1 parent
28e8d9c
commit b794dff
Showing
2 changed files
with
90 additions
and
46 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Implementing the Morais-Lima-Martin (MLM) Sampler\n", | ||
"The notebook shows a brief demonstration of using the built in utilities in `astartes` to implement the Morais-Lima-Martin sampler, which you can read about [here](https://academic.oup.com/bioinformatics/article/35/24/5257/5497250)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"`astartes` has a very fast implementation of the Kennard-Stone algorithm, on which the MLM sampler is based, available in its `utils`:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from astartes.utils.fast_kennard_stone import fast_kennard_stone" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The MLM sampler can then be implemented as shown below.\n", | ||
"The `mlm_sampler` functions takes a 2D array and splits it first using the Kennard-Stone algorithm, then permutes the indices according to the MLM algorithm." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from scipy.spatial.distance import pdist, squareform\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"from astartes.samplers.interpolation import KennardStone\n", | ||
"\n", | ||
"\n", | ||
"def mlm_split(X: np.ndarray, *, train_size: float = 0.8, val_size: float = 0.1, test_size: float = 0.1, random_state: int = 42):\n", | ||
" # calculate the distance matrix\n", | ||
" ks_indexes = fast_kennard_stone(squareform(pdist(X, \"euclidean\")))\n", | ||
" pivot = int(len(ks_indexes) * train_size)\n", | ||
" train_idxs = ks_indexes[0:pivot]\n", | ||
" other_idxs = ks_indexes[pivot:]\n", | ||
"\n", | ||
" # set RNG\n", | ||
" rng = np.random.default_rng(seed=random_state)\n", | ||
" \n", | ||
" # choose 10% of train to switch with 10% of val/test\n", | ||
" n_to_permute = np.floor(0.1 * len(train_idxs))\n", | ||
" train_permute_idxs = rng.choice(train_idxs, n_to_permute)\n", | ||
" remaining_train_idxs = filter(lambda i: i not in train_permute_idxs, train_idxs)\n", | ||
" other_permute_idxs = rng.choice(other_idxs, n_to_permute)\n", | ||
" remaining_other_idxs = filter(lambda i: i not in other_permute_idxs, other_idxs)\n", | ||
"\n", | ||
" # reassemble the new lists of indexes\n", | ||
" new_train_idxs = np.concatenate(remaining_train_idxs, other_permute_idxs)\n", | ||
" new_other_idxs = np.concatenate(train_permute_idxs, remaining_other_idxs)\n", | ||
" n_val = int(len(new_other_idxs) * (val_size / (val_size + test_size)))\n", | ||
" val_indexes = new_other_idxs[0:n_val]\n", | ||
" test_indexes = new_other_idxs[n_val:]\n", | ||
" \n", | ||
" # return the split up array\n", | ||
" return X[train_idxs], X[val_indexes], X[test_indexes]\n", | ||
" " | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "fprop", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python", | ||
"version": "3.11.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |