From f106bde093a04c9eee34129eb2821a8490dc5186 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 14 Feb 2024 17:08:24 -0500 Subject: [PATCH] fix builds due to OneHotEncoder sparse parameter breaking change in scikit-learn --- .github/workflows/CD.yml | 4 ++++ ...shboard-census-classification-model-debugging.ipynb | 10 +++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CD.yml b/.github/workflows/CD.yml index 12ad87e94e..8207e811a4 100644 --- a/.github/workflows/CD.yml +++ b/.github/workflows/CD.yml @@ -84,6 +84,10 @@ jobs: pip install -r requirements.txt pip install -r requirements-dev.txt working-directory: ${{ env.widgetDirectory }} + - name: Install rai_test_utils locally until next version is released + run: | + pip install -v -e . + working-directory: rai_test_utils - name: pip freeze run: pip freeze - name: replace README for raiwidgets diff --git a/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-census-classification-model-debugging.ipynb b/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-census-classification-model-debugging.ipynb index efc8333ef0..b7a99bc4cd 100644 --- a/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-census-classification-model-debugging.ipynb +++ b/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-census-classification-model-debugging.ipynb @@ -88,6 +88,8 @@ "outputs": [], "source": [ "from raiutils.dataset import fetch_dataset\n", + "import sklearn\n", + "from packaging import version\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", @@ -98,6 +100,12 @@ " y = dataset[[target_feature]]\n", " return X, y\n", "\n", + "# for older scikit-learn versions use sparse, for newer sparse_output:\n", + "if version.parse(sklearn.__version__) < version.parse('1.2'):\n", + " ohe_params = {\"sparse\": False}\n", + "else:\n", + " ohe_params = {\"sparse_output\": False}\n", + "\n", "def create_classification_pipeline(X):\n", " pipe_cfg = {\n", " 'num_cols': X.dtypes[X.dtypes == 'int64'].index.values.tolist(),\n", @@ -109,7 +117,7 @@ " ])\n", " cat_pipe = Pipeline([\n", " ('cat_imputer', SimpleImputer(strategy='constant', fill_value='?')),\n", - " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', sparse=False))\n", + " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', **ohe_params))\n", " ])\n", " feat_pipe = ColumnTransformer([\n", " ('num_pipe', num_pipe, pipe_cfg['num_cols']),\n",