diff --git a/aimet_ml/model_selection/__init__.py b/aimet_ml/model_selection/__init__.py index f90482a..8b61eff 100644 --- a/aimet_ml/model_selection/__init__.py +++ b/aimet_ml/model_selection/__init__.py @@ -1 +1 @@ -from .splits import get_splitter, join_cols, split_dataset, stratified_group_split +from .splits import get_splitter, join_cols, split_dataset, split_dataset_v2, stratified_group_split diff --git a/aimet_ml/model_selection/splits.py b/aimet_ml/model_selection/splits.py index 7073ed5..84317d7 100644 --- a/aimet_ml/model_selection/splits.py +++ b/aimet_ml/model_selection/splits.py @@ -183,3 +183,60 @@ def split_dataset( data_splits[val_split_name_format.format(k)] = dev_dataset_df.iloc[val_rows].reset_index(drop=True) return data_splits + + +def split_dataset_v2( + dataset_df: pd.DataFrame, + val_fraction: Union[float, int] = 0.1, + test_n_splits: int = 5, + stratify_cols: Optional[Collection[str]] = None, + group_cols: Optional[Collection[str]] = None, + train_split_name_format: str = "train_fold_{}", + val_split_name_format: str = "val_fold_{}", + test_split_name_format: str = "test_fold_{}", + random_seed: int = 1414, +) -> Dict[str, pd.DataFrame]: + """ + Split a dataset into k-fold cross-validation sets with stratification and grouping. + + Args: + dataset_df (pd.DataFrame): The input DataFrame to be split. + val_fraction (Union[float, int], optional): The fraction of data to be used for validation. + If a float is given, it's rounded to the nearest fraction. + If an integer (n) is given, the fraction is calculated as 1/n. + Defaults to 0.1. + test_n_splits (int, optional): Number of cross-validation splits. Defaults to 5. + stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None. + group_cols (Collection[str], optional): Column names for grouping. Defaults to None. + train_split_name_format (str, optional): Format for naming training splits. Defaults to "train_fold_{}".
+        val_split_name_format (str, optional): Format for naming validation splits. Defaults to "val_fold_{}". + test_split_name_format (str, optional): Format for naming test splits. Defaults to "test_fold_{}". + random_seed (int, optional): Random seed for reproducibility. Defaults to 1414. + + Returns: + Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames. + """ + if test_n_splits <= 1: + raise ValueError("test_n_splits must be greater than 1") + + data_splits = dict() + + # cross-validation split + k_fold_splitter = get_splitter(stratify_cols, group_cols, test_n_splits, random_seed) + + stratify = join_cols(dataset_df, stratify_cols) if stratify_cols else None + groups = join_cols(dataset_df, group_cols) if group_cols else None + + for n, (dev_rows, test_rows) in enumerate(k_fold_splitter.split(X=dataset_df, y=stratify, groups=groups)): + k = n + 1 + data_splits[test_split_name_format.format(k)] = dataset_df.iloc[test_rows].reset_index(drop=True) + + # split the dev set into train and val datasets + dev_dataset_df = dataset_df.iloc[dev_rows].reset_index(drop=True) + train_dataset_df, val_dataset_df = stratified_group_split( + dev_dataset_df, val_fraction, stratify_cols, group_cols, random_seed + ) + data_splits[train_split_name_format.format(k)] = train_dataset_df + data_splits[val_split_name_format.format(k)] = val_dataset_df + + return data_splits diff --git a/tests/model_selection/test_splits.py b/tests/model_selection/test_splits.py index 6c10884..41af3be 100644 --- a/tests/model_selection/test_splits.py +++ b/tests/model_selection/test_splits.py @@ -5,7 +5,7 @@ import pytest from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold -from aimet_ml.model_selection import get_splitter, join_cols, split_dataset, stratified_group_split +from aimet_ml.model_selection import get_splitter, join_cols, split_dataset, split_dataset_v2, stratified_group_split def validate_splits( @@ -255,9 +255,91 @@ def test_split_dataset( 
val_split_name = val_split_name_format.format(k) assert train_split_name in data_splits.keys() assert val_split_name in data_splits.keys() - train_df = data_splits[dev_split_name] - val_df = data_splits[test_split_name] - validate_splits(test_fraction, stratify_cols, group_cols, train_df, val_df) + train_df = data_splits[train_split_name] + val_df = data_splits[val_split_name] + validate_splits(1 / val_n_splits, stratify_cols, group_cols, train_df, val_df) + + +@pytest.mark.parametrize( + "val_fraction, test_n_splits, stratify_cols, group_cols, \ + train_split_name_format, val_split_name_format, test_split_name_format, expectation", + [ + (0.25, 5, ['stratify'], ['group'], 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()), + (4, 5, ['stratify'], ['group'], 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()), + (0.25, 5, ['stratify'], None, 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()), + (4, 5, ['stratify'], None, 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()), + (0.25, 5, None, ['group'], 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()), + (4, 5, None, ['group'], 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()), + (0.25, 5, None, None, 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()), + (4, 5, None, None, 'train-fold{}', 'val-fold{}', 'test-fold{}', does_not_raise()), + ( + 0.25, + 1, + None, + None, + 'train-fold{}', + 'val-fold{}', + 'test-fold{}', + pytest.raises(ValueError, match="test_n_splits must be greater than 1"), + ), + ], +) +def test_split_dataset_v2( + val_fraction: Union[float, int], + test_n_splits: int, + stratify_cols: Optional[Collection[str]], + group_cols: Optional[Collection[str]], + train_split_name_format: str, + val_split_name_format: str, + test_split_name_format: str, + expectation: Any, + sample_df: pd.DataFrame, +): + """ + Test the split_dataset_v2 function. 
+ + Args: + val_fraction (Union[float, int], optional): The fraction of data to be used for validation. + If a float is given, it's rounded to the nearest fraction. + If an integer (n) is given, the fraction is calculated as 1/n. + test_n_splits (int, optional): Number of cross-validation splits. + stratify_cols (Collection[str], optional): Column names for stratification. + group_cols (Collection[str], optional): Column names for grouping. + train_split_name_format (str): Format for naming training splits. + val_split_name_format (str): Format for naming validation splits. + test_split_name_format (str): Format for naming test splits. + expectation (Any): The expected outcome of the test. + sample_df (pd.DataFrame): A sample DataFrame. + """ + with expectation: + data_splits = split_dataset_v2( + sample_df, + val_fraction, + test_n_splits, + stratify_cols, + group_cols, + train_split_name_format, + val_split_name_format, + test_split_name_format, + ) + + for n in range(test_n_splits): + k = n + 1 + train_split_name = train_split_name_format.format(k) + val_split_name = val_split_name_format.format(k) + test_split_name = test_split_name_format.format(k) + + assert train_split_name in data_splits.keys() + assert val_split_name in data_splits.keys() + assert test_split_name in data_splits.keys() + + train_df = data_splits[train_split_name] + val_df = data_splits[val_split_name] + test_df = data_splits[test_split_name] + dev_df = pd.concat([train_df, val_df], axis=0, ignore_index=True) + + validate_splits(1 / test_n_splits, stratify_cols, group_cols, dev_df, test_df) + validate_splits(val_fraction, stratify_cols, group_cols, train_df, val_df) if __name__ == "__main__":