Skip to content

Commit

Permalink
Merge pull request #7 from aimet-tech/feature/new-split-method
Browse files Browse the repository at this point in the history
Add new split method
  • Loading branch information
aimet-pasitpk authored Mar 6, 2024
2 parents ae876e4 + a2f57aa commit 65f6587
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 5 deletions.
2 changes: 1 addition & 1 deletion aimet_ml/model_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .splits import get_splitter, join_cols, split_dataset, stratified_group_split
from .splits import get_splitter, join_cols, split_dataset, split_dataset_v2, stratified_group_split
57 changes: 57 additions & 0 deletions aimet_ml/model_selection/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,60 @@ def split_dataset(
data_splits[val_split_name_format.format(k)] = dev_dataset_df.iloc[val_rows].reset_index(drop=True)

return data_splits


def split_dataset_v2(
dataset_df: pd.DataFrame,
val_fraction: Union[float, int] = 0.1,
test_n_splits: int = 5,
stratify_cols: Optional[Collection[str]] = None,
group_cols: Optional[Collection[str]] = None,
train_split_name_format: str = "train_fold_{}",
val_split_name_format: str = "val_fold_{}",
test_split_name_format: str = "test_fold_{}",
random_seed: int = 1414,
) -> Dict[str, pd.DataFrame]:
"""
Split a dataset into k-fold cross-validation sets with stratification and grouping.
Args:
dataset_df (pd.DataFrame): The input DataFrame to be split.
val_fraction (Union[float, int], optional): The fraction of data to be used for validation.
If a float is given, it's rounded to the nearest fraction.
If an integer (n) is given, the fraction is calculated as 1/n.
Defaults to 0.2.
test_n_splits (int, optional): Number of cross-validation splits. Defaults to 5.
stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None.
group_cols (Collection[str], optional): Column names for grouping. Defaults to None.
train_split_name_format (str, optional): Format for naming training splits. Defaults to "train_fold_{}".
val_split_name_format (str, optional): Format for naming validation splits. Defaults to "val_fold_{}".
test_split_name_format (str, optional): Format for naming validation splits. Defaults to "test_fold_{}".
random_seed (int, optional): Random seed for reproducibility. Defaults to 1414.
Returns:
Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames.
"""
if test_n_splits <= 1:
raise ValueError("test_n_splits must be greater than 1")

data_splits = dict()

# cross-validation split
k_fold_splitter = get_splitter(stratify_cols, group_cols, test_n_splits, random_seed)

stratify = join_cols(dataset_df, stratify_cols) if stratify_cols else None
groups = join_cols(dataset_df, group_cols) if group_cols else None

for n, (dev_rows, test_rows) in enumerate(k_fold_splitter.split(X=dataset_df, y=stratify, groups=groups)):
k = n + 1
data_splits[test_split_name_format.format(k)] = dataset_df.iloc[test_rows].reset_index(drop=True)

# split into dev and test datasets
dev_dataset_df = dataset_df.iloc[dev_rows].reset_index(drop=True)
train_dataset_df, val_dataset_df = stratified_group_split(
dev_dataset_df, val_fraction, stratify_cols, group_cols, random_seed
)
data_splits[train_split_name_format.format(k)] = train_dataset_df
data_splits[val_split_name_format.format(k)] = val_dataset_df

return data_splits
90 changes: 86 additions & 4 deletions tests/model_selection/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold

from aimet_ml.model_selection import get_splitter, join_cols, split_dataset, stratified_group_split
from aimet_ml.model_selection import get_splitter, join_cols, split_dataset, split_dataset_v2, stratified_group_split


def validate_splits(
Expand Down Expand Up @@ -255,9 +255,91 @@ def test_split_dataset(
val_split_name = val_split_name_format.format(k)
assert train_split_name in data_splits.keys()
assert val_split_name in data_splits.keys()
train_df = data_splits[dev_split_name]
val_df = data_splits[test_split_name]
validate_splits(test_fraction, stratify_cols, group_cols, train_df, val_df)
train_df = data_splits[train_split_name]
val_df = data_splits[val_split_name]
validate_splits(1 / val_n_splits, stratify_cols, group_cols, train_df, val_df)


@pytest.mark.parametrize(
"val_fraction, test_n_splits, stratify_cols, group_cols, \
train_split_name_format, val_split_name_format, test_split_name_format, expectation",
[
(0.25, 5, ['stratify'], ['group'], 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()),
(4, 5, ['stratify'], ['group'], 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()),
(0.25, 5, ['stratify'], None, 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()),
(4, 5, ['stratify'], None, 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()),
(0.25, 5, None, ['group'], 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()),
(4, 5, None, ['group'], 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()),
(0.25, 5, None, None, 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()),
(4, 5, None, None, 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()),
(
0.25,
1,
None,
None,
'train-fold{}',
'val-fold{}',
'test-fold{}',
pytest.raises(ValueError, match="test_n_splits must be greater than 1"),
),
],
)
def test_split_dataset_v2(
val_fraction: Union[float, int],
test_n_splits: int,
stratify_cols: Optional[Collection[str]],
group_cols: Optional[Collection[str]],
train_split_name_format: str,
val_split_name_format: str,
test_split_name_format: str,
expectation: Any,
sample_df: pd.DataFrame,
):
"""
Test the split_dataset_v2 function.
Args:
val_fraction (Union[float, int], optional): The fraction of data to be used for validation.
If a float is given, it's rounded to the nearest fraction.
If an integer (n) is given, the fraction is calculated as 1/n.
test_n_splits (int, optional): Number of cross-validation splits.
stratify_cols (Collection[str], optional): Column names for stratification.
group_cols (Collection[str], optional): Column names for grouping.
train_split_name_format (str): Format for naming training splits.
val_split_name_format (str): Format for naming validation splits.
test_split_name_format (str): Format for naming test splits.
expectation (Any): The expected outcome of the test.
sample_df (pd.DataFrame): A sample DataFrame.
"""
with expectation:
data_splits = split_dataset_v2(
sample_df,
val_fraction,
test_n_splits,
stratify_cols,
group_cols,
train_split_name_format,
val_split_name_format,
test_split_name_format,
)

for n in range(test_n_splits):
k = n + 1
train_split_name = train_split_name_format.format(k)
val_split_name = val_split_name_format.format(k)
test_split_name = test_split_name_format.format(k)

assert train_split_name in data_splits.keys()
assert val_split_name in data_splits.keys()
assert test_split_name in data_splits.keys()

train_df = data_splits[train_split_name]
val_df = data_splits[val_split_name]
test_df = data_splits[test_split_name]
dev_df = pd.concat([train_df, val_df], axis=0, ignore_index=True)

validate_splits(1 / test_n_splits, stratify_cols, group_cols, dev_df, test_df)
validate_splits(val_fraction, stratify_cols, group_cols, train_df, val_df)


if __name__ == "__main__":
Expand Down

0 comments on commit 65f6587

Please sign in to comment.