diff --git a/.github/workflows/CI-python.yml b/.github/workflows/CI-python.yml index fcaae683e0..b9d9485a9d 100644 --- a/.github/workflows/CI-python.yml +++ b/.github/workflows/CI-python.yml @@ -58,6 +58,12 @@ jobs: pip install -v -e . working-directory: ${{ matrix.packageDirectory }} + - if: ${{ (matrix.packageDirectory == 'erroranalysis') || (matrix.packageDirectory == 'responsibleai') }} + name: Install rai_test_utils locally until next version is released + run: | + pip install -v -e . + working-directory: rai_test_utils + - name: Pip freeze run: | pip freeze > installed-requirements-dev.txt diff --git a/apps/dashboard/src/app/textApplications.ts b/apps/dashboard/src/app/textApplications.ts index 4adb44754e..56c072dbc8 100644 --- a/apps/dashboard/src/app/textApplications.ts +++ b/apps/dashboard/src/app/textApplications.ts @@ -17,6 +17,7 @@ import { emotionModelExplanationData } from "../model-assessment-text/__mock_data__/emotion"; import { squad } from "../model-assessment-text/__mock_data__/squad"; +import { squadGenai } from "../model-assessment-text/__mock_data__/squadGenai"; import { IDataSet, @@ -65,6 +66,10 @@ export const textApplications: ITextApplications = { squad: { classDimension: 3, dataset: squad + } as IModelAssessmentDataSet, + squadGenai: { + classDimension: 3, + dataset: squadGenai } as IModelAssessmentDataSet }, versions: { "1": 1, "2:Static-View": 2 } diff --git a/apps/dashboard/src/model-assessment-forecasting/__mock_data__/mockForecastingDataSingleTimeSeries.ts b/apps/dashboard/src/model-assessment-forecasting/__mock_data__/mockForecastingDataSingleTimeSeries.ts index 6457722b82..fe807c0f87 100644 --- a/apps/dashboard/src/model-assessment-forecasting/__mock_data__/mockForecastingDataSingleTimeSeries.ts +++ b/apps/dashboard/src/model-assessment-forecasting/__mock_data__/mockForecastingDataSingleTimeSeries.ts @@ -28,7 +28,9 @@ mockForecastingDataSingleTimeSeries.predicted_y = startingIndexBobsSandwichesTimeSeries, endingIndexBobsSandwichesTimeSeries ); -mockForecastingDataSingleTimeSeries.true_y = mockForecastingData.true_y.slice( - startingIndexBobsSandwichesTimeSeries, - endingIndexBobsSandwichesTimeSeries -); +if (mockForecastingData.true_y) { + mockForecastingDataSingleTimeSeries.true_y = mockForecastingData.true_y.slice( + startingIndexBobsSandwichesTimeSeries, + endingIndexBobsSandwichesTimeSeries + ); +} diff --git a/apps/dashboard/src/model-assessment-text/__mock_data__/squadGenai.ts b/apps/dashboard/src/model-assessment-text/__mock_data__/squadGenai.ts new file mode 100644 index 0000000000..52b0007b04 --- /dev/null +++ b/apps/dashboard/src/model-assessment-text/__mock_data__/squadGenai.ts @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { DatasetTaskType, IDataset } from "@responsible-ai/core-ui"; + +export const squadGenai: IDataset = { + categorical_features: [], + class_names: undefined, + feature_names: [ + "context", + "prompt", + "positive_words", + "negative_words", + "negation_words", + "negated_entities", + "named_persons", + "sentence_length" + ], + features: [ + [ + 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', + 'Answer the question given the context.\n\ncontext:\nArchitecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.\n\nquestion:\nTo whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', + 50, + 0, + 0, + 0, + 3, + 827 + ], + [ + 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', + 'Answer the question given the context.\n\ncontext:\nArchitecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.\n\nquestion:\nWhat is in front of the Notre Dame Main Building?', + 50, + 0, + 0, + 0, + 2, + 805 + ], + [ + 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', + 'Answer the question given the context.\n\ncontext:\nArchitecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.\n\nquestion:\nThe Basilica of the Sacred heart at Notre Dame is beside to which structure?', + 52, + 0, + 0, + 0, + 3, + 832 + ] + ], + predicted_y: [ + "This is a dummy answer", + "This is a dummy answer", + "This is a dummy answer" + ], + target_column: undefined, + task_type: DatasetTaskType.GenerativeText, + true_y: undefined +}; diff --git a/libs/core-ui/src/lib/DatasetCohort.ts b/libs/core-ui/src/lib/DatasetCohort.ts index 74f92049ed..4753abed90 100644 --- a/libs/core-ui/src/lib/DatasetCohort.ts +++ b/libs/core-ui/src/lib/DatasetCohort.ts @@ -120,7 +120,7 @@ export class DatasetCohort { dataDict[index][featureName] = val; }); }); - this.dataset.true_y.forEach((val, index) => { + this.dataset.true_y?.forEach((val, index) => { if (Array.isArray(val)) { val.forEach((subVal, subIndex) => { dataDict[index][DatasetCohortColumns.TrueY + subIndex.toString()] = diff --git a/libs/core-ui/src/lib/Interfaces/IDataset.ts b/libs/core-ui/src/lib/Interfaces/IDataset.ts index ea05d3df2e..832c8d4e01 100644 --- a/libs/core-ui/src/lib/Interfaces/IDataset.ts +++ b/libs/core-ui/src/lib/Interfaces/IDataset.ts @@ -13,7 +13,8 @@ export enum DatasetTaskType { MultilabelImageClassification = "multilabel_image_classification", Forecasting = "forecasting", ObjectDetection = "object_detection", - QuestionAnswering = "question_answering" + QuestionAnswering = "question_answering", + GenerativeText = "generative_text" } export interface ITabularDatasetMetadata { @@ -31,7 +32,7 @@ export interface IObjectDetectionLabelType { export interface IDataset { task_type: DatasetTaskType; - true_y: number[] | number[][] | string[]; + true_y?: number[] | number[][] | string[]; predicted_y?: number[] | number[][] | string[]; probability_y?: number[][]; features: unknown[][]; diff --git a/libs/core-ui/src/lib/util/datasetUtils/getColumnRanges.ts b/libs/core-ui/src/lib/util/datasetUtils/getColumnRanges.ts index c1b84bd144..61fde5cabf 100644 --- a/libs/core-ui/src/lib/util/datasetUtils/getColumnRanges.ts +++ b/libs/core-ui/src/lib/util/datasetUtils/getColumnRanges.ts @@ -161,7 +161,11 @@ function getRegressionErrorFeatureRange( dataset: IDataset, modelType: ModelTypes ): IColumnRange | undefined { - if (modelType === ModelTypes.Regression && dataset.predicted_y) { + if ( + modelType === ModelTypes.Regression && + dataset.predicted_y && + dataset.true_y + ) { const regressionErrors = []; for (let index = 0; index < dataset.features.length; index++) { const trueY = dataset.true_y[index]; diff --git a/libs/core-ui/src/lib/util/datasetUtils/getPropertyValues.ts b/libs/core-ui/src/lib/util/datasetUtils/getPropertyValues.ts index 638e152929..eeb5bd7477 100644 --- a/libs/core-ui/src/lib/util/datasetUtils/getPropertyValues.ts +++ b/libs/core-ui/src/lib/util/datasetUtils/getPropertyValues.ts @@ -44,9 +44,12 @@ export function getPropertyValues( }); } if (property === DatasetCohortColumns.TrueY) { - return indexes.map((index) => { - return dataset.true_y[index]; - }); + const trueYs = dataset.true_y; + if (trueYs) { + return indexes.map((index) => { + return trueYs[index]; + }); + } } if (dataset.predicted_y && dataset.true_y) { return getErrors(property, indexes, dataset, modelType); @@ -62,6 +65,7 @@ function getErrors( ): unknown[] { if ( dataset.predicted_y && + dataset.true_y && !Array.isArray(dataset.true_y) && !Array.isArray(dataset.predicted_y) ) { diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/DetectionDetails.tsx b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/DetectionDetails.tsx new file mode 100644 index 0000000000..2daf59ebcb --- /dev/null +++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/DetectionDetails.tsx @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { Stack, Text } from "@fluentui/react"; +import { IVisionListItem } from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; +import React from "react"; + +import { stackTokens } from "./FlyoutObjectDetectionUtils"; + +interface IDetectionDetailsProps { + item: IVisionListItem; // replace with actual type + correctDetections: string; + incorrectDetections: string; +} +export class DetectionDetails extends React.Component { + public render(): React.ReactNode { + return ( + + + + + {localization.InterpretVision.Dashboard.indexLabel} + {this.props.item?.index} + + + + + {localization.InterpretVision.Dashboard.correctDetections} + {this.props.correctDetections} + + + + + {localization.InterpretVision.Dashboard.incorrectDetections} + {this.props.incorrectDetections} + + + + ); + } +} diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/FlyoutObjectDetection.tsx b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/FlyoutObjectDetection.tsx index d32c30cdc6..12909ba053 100644 --- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/FlyoutObjectDetection.tsx +++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/FlyoutObjectDetection.tsx @@ -7,6 +7,7 @@ import { IComboBox, IComboBoxOption, Image, + Label, List, Panel, PanelType, @@ -24,6 +25,7 @@ import * as FlyoutStyles from "../utils/FlyoutUtils"; import { getObjectDetectionImageAltText } from "../utils/getAltTextUtils"; import { getJoinedLabelString } from "../utils/labelUtils"; +import { DetectionDetails } from "./DetectionDetails"; import { flyoutStyles, explanationImage, @@ -55,7 +57,8 @@ export class FlyoutObjectDetection extends React.Component< const selectableObjectIndexes = FlyoutStyles.generateSelectableObjectDetectionIndexes( localization.InterpretVision.Dashboard.prefix, - item + item, + this.props.dataset.class_names ); this.setState({ item, metadata, selectableObjectIndexes }); } @@ -73,7 +76,8 @@ export class FlyoutObjectDetection extends React.Component< const selectableObjectIndexes = FlyoutStyles.generateSelectableObjectDetectionIndexes( localization.InterpretVision.Dashboard.prefix, - item + item, + this.props.dataset.class_names ); this.setState({ item: this.props.item, @@ -122,42 +126,11 @@ export class FlyoutObjectDetection extends React.Component< verticalAlign="center" > - - - - - {localization.InterpretVision.Dashboard.indexLabel} - {item?.index} - - - - - { - localization.InterpretVision.Dashboard - .correctDetections - } - {correctDetections} - - - - - { - localization.InterpretVision.Dashboard - .incorrectDetections - } - {incorrectDetections} - - - + @@ -173,13 +146,17 @@ export class FlyoutObjectDetection extends React.Component< {localization.InterpretVision.Dashboard.panelInformation} - + + @@ -200,17 +177,27 @@ export class FlyoutObjectDetection extends React.Component< - { - - } + + {!this.props.loadingExplanation[item.index][ +this.state.odSelectedKey.slice( diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/FlyoutObjectDetectionUtils.tsx b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/FlyoutObjectDetectionUtils.tsx index 192a05e0aa..6de3e33643 100644 --- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/FlyoutObjectDetectionUtils.tsx +++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/FlyoutObjectDetectionUtils.tsx @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -import { IComboBoxOption, getTheme } from "@fluentui/react"; +import { IComboBoxOption, getTheme, Stack, Text } from "@fluentui/react"; import { IDataset, IVisionListItem } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; +import React from "react"; import { CanvasTools } from "vott-ct"; import { Editor } from "vott-ct/lib/js/CanvasTools/CanvasTools.Editor"; import { RegionData } from "vott-ct/lib/js/CanvasTools/Core/RegionData"; @@ -31,6 +32,42 @@ export const stackTokens = { large: { childrenGap: "l2" }, medium: { childrenGap: "l1" } }; + +const theme = getTheme(); + +export class ColorLegend extends React.Component { + public render(): React.ReactNode { + return ( + + + + {localization.InterpretVision.Dashboard.trueY} + +
+ + + + {localization.InterpretVision.Dashboard.predictedY} + +
+ + + ); + } +} + export const ExcessLabelLen = localization.InterpretVision.Dashboard.prefix.length; @@ -109,8 +146,6 @@ export function drawBoundingBoxes( return; } - const theme = getTheme(); - // Ensuring object detection labels are populated if ( !dataset.object_detection_predicted_y || diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/utils/FlyoutUtils.tsx b/libs/interpret-vision/src/lib/VisionExplanationDashboard/utils/FlyoutUtils.tsx index 57ce7520bc..b78f482f7c 100644 --- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/utils/FlyoutUtils.tsx +++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/utils/FlyoutUtils.tsx @@ -34,13 +34,14 @@ export function onRenderCell( export function generateSelectableObjectDetectionIndexes( prefix: string, - item: IVisionListItem | undefined + item: IVisionListItem | undefined, + classNames: string[] | undefined ): IComboBoxOption[] { const temp = item?.odPredictedY; const selectableObjectIndexes: IComboBoxOption[] = []; - if (temp) { + if (temp && classNames) { for (let i = 0; i < Object.values(temp).length; i++) { - const className = item?.predictedY[i]; + const className = classNames[temp[i][0] - 1]; selectableObjectIndexes.push({ key: prefix + String(i), text: String(i) + String(": ") + className diff --git a/libs/localization/src/lib/en.json b/libs/localization/src/lib/en.json index c9d6c44baa..4f98babeb2 100644 --- a/libs/localization/src/lib/en.json +++ b/libs/localization/src/lib/en.json @@ -1477,6 +1477,7 @@ "multiselect": "Multiselect", "notdefined": "object scenario not defined", "objectSelect": "Object Selection", + "objectSelectionLabel": "objectSelectLabel", "pageSize": "Page size: ", "panelTitle": "Selected instance", "panelExplanation": "Explanation", diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/IndividualFeatureImportanceView/TextLocalImportancePlots.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/IndividualFeatureImportanceView/TextLocalImportancePlots.tsx index 2f14c412da..bcccdfdc89 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/IndividualFeatureImportanceView/TextLocalImportancePlots.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/IndividualFeatureImportanceView/TextLocalImportancePlots.tsx @@ -86,7 +86,7 @@ export class TextLocalImportancePlots extends React.Component
diff --git a/notebooks/individual-dashboards/erroranalysis-dashboard/erroranalysis-interpretability-dashboard-census.ipynb b/notebooks/individual-dashboards/erroranalysis-dashboard/erroranalysis-interpretability-dashboard-census.ipynb index d8ad57e2f0..d5ba5c57b4 100644 --- a/notebooks/individual-dashboards/erroranalysis-dashboard/erroranalysis-interpretability-dashboard-census.ipynb +++ b/notebooks/individual-dashboards/erroranalysis-dashboard/erroranalysis-interpretability-dashboard-census.ipynb @@ -118,11 +118,19 @@ "metadata": {}, "outputs": [], "source": [ + "from packaging import version\n", + "import sklearn\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.compose import ColumnTransformer\n", "\n", + "# for older scikit-learn versions use sparse, for newer sparse_output:\n", + "if version.parse(sklearn.__version__) < version.parse('1.2'):\n", + " ohe_params = {\"sparse\": False}\n", + "else:\n", + " ohe_params = {\"sparse_output\": False}\n", + "\n", "def split_label(dataset):\n", " X = dataset.drop(['income'], axis=1)\n", " y = dataset[['income']]\n", @@ -141,7 +149,7 @@ " ])\n", " cat_pipe = Pipeline([\n", " ('cat_imputer', SimpleImputer(strategy='constant', fill_value='?')),\n", - " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', sparse=False))\n", + " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', **ohe_params))\n", " ])\n", " feat_pipe = ColumnTransformer([\n", " ('num_pipe', num_pipe, pipe_cfg['num_cols']),\n", diff --git a/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-housing-classification-model-debugging.ipynb b/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-housing-classification-model-debugging.ipynb index 490b33a119..ed1c234e2d 100644 --- a/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-housing-classification-model-debugging.ipynb +++ b/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-housing-classification-model-debugging.ipynb @@ -70,12 +70,20 @@ "metadata": {}, "outputs": [], "source": [ + "from packaging import version\n", "from raiutils.dataset import fetch_dataset\n", + "import sklearn\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import OneHotEncoder\n", "from sklearn.compose import ColumnTransformer\n", "\n", + "# for older scikit-learn versions use sparse, for newer sparse_output:\n", + "if version.parse(sklearn.__version__) < version.parse('1.2'):\n", + " ohe_params = {\"sparse\": False}\n", + "else:\n", + " ohe_params = {\"sparse_output\": False}\n", + "\n", "def split_label(dataset, target_feature):\n", " X = dataset.drop([target_feature], axis=1)\n", " y = dataset[[target_feature]]\n", @@ -93,7 +101,7 @@ " ])\n", " cat_pipe = Pipeline([\n", " ('cat_imputer', SimpleImputer(strategy='constant', fill_value='?')),\n", - " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', sparse=False))\n", + " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', **ohe_params))\n", " ])\n", " feat_pipe = ColumnTransformer([\n", " ('num_pipe', num_pipe, pipe_cfg['num_cols']),\n", @@ -179,7 +187,7 @@ "source": [ "To use Responsible AI Dashboard, initialize a RAIInsights object upon which different components can be loaded.\n", "\n", - "RAIInsights accepts the model, the full dataset, the test dataset, the target feature string and the task type string as its arguments.", + "RAIInsights accepts the model, the full dataset, the test dataset, the target feature string and the task type string as its arguments.\n", "\n", "You may also create the `FeatureMetadata` container, identify any feature of your choice as the `identity_feature`, specify a list of strings of categorical feature names via the `categorical_features` parameter, and specify dropped features via the `dropped_features` parameter. The `FeatureMetadata` may also be passed into the `RAIInsights`." ] diff --git a/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-housing-decision-making.ipynb b/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-housing-decision-making.ipynb index 333e9a2b29..7beafba689 100644 --- a/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-housing-decision-making.ipynb +++ b/notebooks/responsibleaidashboard/tabular/responsibleaidashboard-housing-decision-making.ipynb @@ -59,12 +59,20 @@ "metadata": {}, "outputs": [], "source": [ + "from packaging import version\n", "from raiutils.dataset import fetch_dataset\n", + "import sklearn\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.compose import ColumnTransformer\n", "\n", + "# for older scikit-learn versions use sparse, for newer sparse_output:\n", + "if version.parse(sklearn.__version__) < version.parse('1.2'):\n", + " ohe_params = {\"sparse\": False}\n", + "else:\n", + " ohe_params = {\"sparse_output\": False}\n", + "\n", "def split_label(dataset, target_feature):\n", " X = dataset.drop([target_feature], axis=1)\n", " y = dataset[[target_feature]]\n", @@ -83,7 +91,7 @@ " ])\n", " cat_pipe = Pipeline([\n", " ('cat_imputer', SimpleImputer(strategy='constant', fill_value='?')),\n", - " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', sparse=False))\n", + " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', **ohe_params))\n", " ])\n", " feat_pipe = ColumnTransformer([\n", " ('num_pipe', num_pipe, pipe_cfg['num_cols']),\n", @@ -148,7 +156,7 @@ "source": [ "To use Responsible AI Dashboard, initialize a RAIInsights object upon which different components can be loaded.\n", "\n", - "RAIInsights accepts the model, the full dataset, the test dataset, the target feature string and the task type string as its arguments.", + "RAIInsights accepts the model, the full dataset, the test dataset, the target feature string and the task type string as its arguments.\n", "\n", "You may also create the `FeatureMetadata` container, identify any feature of your choice as the `identity_feature`, specify a list of strings of categorical feature names via the `categorical_features` parameter, and specify dropped features via the `dropped_features` parameter. The `FeatureMetadata` may also be passed into the `RAIInsights`." ] diff --git a/rai_test_utils/rai_test_utils/models/sklearn/sklearn_model_utils.py b/rai_test_utils/rai_test_utils/models/sklearn/sklearn_model_utils.py index 9866fad009..4721d4ce01 100644 --- a/rai_test_utils/rai_test_utils/models/sklearn/sklearn_model_utils.py +++ b/rai_test_utils/rai_test_utils/models/sklearn/sklearn_model_utils.py @@ -3,6 +3,8 @@ import numpy as np import pandas as pd +import sklearn +from packaging import version from sklearn import svm from sklearn.compose import ColumnTransformer from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor @@ -127,6 +129,11 @@ def conv(X): (conv(np.prod(x, axis=1)).reshape(-1, 1), conv(np.prod(x, axis=1)**2).reshape(-1, 1)) )) + # for older scikit-learn versions use sparse, for newer sparse_output: + if version.parse(sklearn.__version__) < version.parse('1.2'): + ohe_params = {"sparse": False} + else: + ohe_params = {"sparse_output": False} transformations = ColumnTransformer([ ("age_fare_1", Pipeline(steps=[ ('imputer', SimpleImputer(strategy='median')), @@ -137,8 +144,8 @@ def conv(X): ("embarked", Pipeline(steps=[ ("imputer", SimpleImputer(strategy='constant', fill_value='missing')), - ("encoder", OneHotEncoder(sparse=False))]), ["embarked"]), - ("sex_pclass", OneHotEncoder(sparse=False), ["sex", "pclass"]) + ("encoder", OneHotEncoder(**ohe_params))]), ["embarked"]), + ("sex_pclass", OneHotEncoder(**ohe_params), ["sex", "pclass"]) ]) clf = Pipeline(steps=[('preprocessor', transformations), ('classifier', diff --git a/raiwidgets/raiwidgets/dashboard.py b/raiwidgets/raiwidgets/dashboard.py index c70f2d4816..c0b00a9da6 100644 --- a/raiwidgets/raiwidgets/dashboard.py +++ b/raiwidgets/raiwidgets/dashboard.py @@ -61,6 +61,7 @@ def __init__(self, *, port, locale, no_inline_dashboard=False, + is_private_link=False, **kwargs): """Initialize the dashboard.""" @@ -68,7 +69,9 @@ def __init__(self, *, raise ValueError("Required parameters not provided") try: - self._service = FlaskHelper(ip=public_ip, port=port) + self._service = FlaskHelper(ip=public_ip, + port=port, + is_private_link=is_private_link) except Exception as e: self._service = None raise e diff --git a/raiwidgets/raiwidgets/responsibleai_dashboard.py b/raiwidgets/raiwidgets/responsibleai_dashboard.py index fc194258ad..3ac0a11fff 100644 --- a/raiwidgets/raiwidgets/responsibleai_dashboard.py +++ b/raiwidgets/raiwidgets/responsibleai_dashboard.py @@ -29,10 +29,14 @@ class ResponsibleAIDashboard(Dashboard): :param cohort_list: List of cohorts defined by the user for the dashboard. :type cohort_list: List[Cohort] + :param is_private_link: If the dashboard environment is + a private link AML workspace. + :type is_private_link: bool """ def __init__(self, analysis: RAIInsights, public_ip=None, port=None, locale=None, - cohort_list=None, **kwargs): + cohort_list=None, is_private_link=False, + **kwargs): self.input = ResponsibleAIDashboardInput( analysis, cohort_list=cohort_list) @@ -43,6 +47,7 @@ def __init__(self, analysis: RAIInsights, port=port, locale=locale, no_inline_dashboard=True, + is_private_link=is_private_link, **kwargs) def predict(): diff --git a/raiwidgets/requirements.txt b/raiwidgets/requirements.txt index 0bac7cde2a..be5d8cddb7 100644 --- a/raiwidgets/requirements.txt +++ b/raiwidgets/requirements.txt @@ -1,7 +1,7 @@ numpy>=1.17.2,<=1.26.2 pandas>=0.25.1,<2.0.0 scipy>=1.4.1 -rai-core-flask==0.7.2 +rai-core-flask==0.7.3 itsdangerous<=2.1.2 scikit-learn>=0.22.1 lightgbm>=2.0.11