From 12bfef87821a029690317e9f8c08acf39be5123f Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 29 Jan 2024 11:37:02 -0500 Subject: [PATCH] add genai metrics endpoint in UI for model overview metrics (#2517) --- apps/widget/src/app/ModelAssessment.tsx | 14 ++ apps/widget/src/app/ModelAssessmentUtils.tsx | 1 + libs/core-ui/src/index.ts | 1 + .../lib/Context/ModelAssessmentContext.tsx | 7 + .../src/lib/Interfaces/IExplanationContext.ts | 1 + .../lib/util/GenerativeTextStatisticsUtils.ts | 88 +++++++++++ libs/core-ui/src/lib/util/StatisticsUtils.ts | 7 +- .../MetricSelector/MetricSelector.tsx | 5 +- libs/localization/src/lib/en.json | 25 ++++ .../Controls/ModelOverview/ModelOverview.tsx | 138 ++++++++++++++++-- .../Controls/ModelOverview/StatsTableUtils.ts | 88 ++++++++++- .../Controls/TabsView/TabsViewProps.ts | 8 - .../ModelAssessmentDashboard.tsx | 4 +- .../ModelAssessmentDashboardProps.ts | 5 + .../utils/getModelTypeFromProps.ts | 3 + 15 files changed, 367 insertions(+), 28 deletions(-) create mode 100644 libs/core-ui/src/lib/util/GenerativeTextStatisticsUtils.ts diff --git a/apps/widget/src/app/ModelAssessment.tsx b/apps/widget/src/app/ModelAssessment.tsx index 96a42db49e..cb2e8233a5 100644 --- a/apps/widget/src/app/ModelAssessment.tsx +++ b/apps/widget/src/app/ModelAssessment.tsx @@ -71,6 +71,20 @@ export class ModelAssessment extends React.Component { abortSignal ); }; + callBack.requestGenerativeTextMetrics = async ( + selectionIndexes: number[][], + generativeTextCache: Map>, + abortSignal: AbortSignal + ): Promise => { + const parameters = [selectionIndexes, generativeTextCache]; + return connectToFlaskServiceWithBackupCall( + this.props.config, + parameters, + "handle_generative_text_json", + "/get_generative_text_metrics", + abortSignal + ); + }; callBack.requestMatrix = async ( data: any[] ): Promise => { diff --git a/apps/widget/src/app/ModelAssessmentUtils.tsx b/apps/widget/src/app/ModelAssessmentUtils.tsx index ab5f705c27..dcfb90caa7 100644 --- a/apps/widget/src/app/ModelAssessmentUtils.tsx +++ b/apps/widget/src/app/ModelAssessmentUtils.tsx @@ -16,6 +16,7 @@ export interface IModelAssessmentProps { export type CallbackType = Pick< IModelAssessmentDashboardProps, | "requestExp" + | "requestGenerativeTextMetrics" | "requestObjectDetectionMetrics" | "requestPredictions" | "requestQuestionAnsweringMetrics" diff --git a/libs/core-ui/src/index.ts b/libs/core-ui/src/index.ts index d14463457c..ce3b04dac1 100644 --- a/libs/core-ui/src/index.ts +++ b/libs/core-ui/src/index.ts @@ -56,6 +56,7 @@ export * from "./lib/util/getFilterBoundsArgs"; export * from "./lib/util/calculateBoxData"; export * from "./lib/util/calculateConfusionMatrixData"; export * from "./lib/util/calculateLineData"; +export * from "./lib/util/GenerativeTextStatisticsUtils"; export * from "./lib/util/MultilabelStatisticsUtils"; export * from "./lib/util/ObjectDetectionStatisticsUtils"; export * from "./lib/util/QuestionAnsweringStatisticsUtils"; diff --git a/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx b/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx index d57ce7a3ec..b5e2011765 100644 --- a/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx +++ b/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx @@ -140,6 +140,13 @@ export interface IModelAssessmentContext { requestExp?: | ((index: number | number[], abortSignal: AbortSignal) => Promise) | undefined; + requestGenerativeTextMetrics?: + | (( + selectionIndexes: number[][], + generativeTextCache: Map>, + abortSignal: AbortSignal + ) => Promise) + | undefined; requestObjectDetectionMetrics?: | (( selectionIndexes: number[][], diff --git a/libs/core-ui/src/lib/Interfaces/IExplanationContext.ts b/libs/core-ui/src/lib/Interfaces/IExplanationContext.ts index 40ab33a359..0f4cb3b1f9 100644 --- a/libs/core-ui/src/lib/Interfaces/IExplanationContext.ts +++ b/libs/core-ui/src/lib/Interfaces/IExplanationContext.ts @@ -8,6 +8,7 @@ import { JointDataset } from "../util/JointDataset"; export enum ModelTypes { Regression = "regression", Binary = "binary", + GenerativeText = "generativetext", Multiclass = "multiclass", ImageBinary = "imagebinary", ImageMulticlass = "imagemulticlass", diff --git a/libs/core-ui/src/lib/util/GenerativeTextStatisticsUtils.ts b/libs/core-ui/src/lib/util/GenerativeTextStatisticsUtils.ts new file mode 100644 index 0000000000..b36ca6be2e --- /dev/null +++ b/libs/core-ui/src/lib/util/GenerativeTextStatisticsUtils.ts @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { localization } from "@responsible-ai/localization"; + +import { + ILabeledStatistic, + TotalCohortSamples +} from "../Interfaces/IStatistic"; + +import { QuestionAnsweringMetrics } from "./QuestionAnsweringStatisticsUtils"; + +export enum GenerativeTextMetrics { + Coherence = "coherence", + Fluency = "fluency", + Equivalence = "equivalence", + Groundedness = "groundedness", + Relevance = "relevance" +} + +export const generateGenerativeTextStats: ( + selectionIndexes: number[][], + generativeTextCache: Map> +) => ILabeledStatistic[][] = ( + selectionIndexes: number[][], + generativeTextCache: Map> +): ILabeledStatistic[][] => { + return selectionIndexes.map((selectionArray) => { + const count = selectionArray.length; + + const value = generativeTextCache.get(selectionArray.toString()); + const stat: Map = value ? value : new Map(); + + const stats = [ + { + key: TotalCohortSamples, + label: localization.Interpret.Statistics.samples, + stat: count + } + ]; + for (const [key, value] of stat.entries()) { + let label = ""; + switch (key) { + case GenerativeTextMetrics.Coherence: + label = localization.Interpret.Statistics.coherence; + break; + case GenerativeTextMetrics.Fluency: + label = localization.Interpret.Statistics.fluency; + break; + case GenerativeTextMetrics.Equivalence: + label = localization.Interpret.Statistics.equivalence; + break; + case GenerativeTextMetrics.Groundedness: + label = localization.Interpret.Statistics.groundedness; + break; + case GenerativeTextMetrics.Relevance: + label = localization.Interpret.Statistics.relevance; + break; + case QuestionAnsweringMetrics.ExactMatchRatio: + label = localization.Interpret.Statistics.exactMatchRatio; + break; + case QuestionAnsweringMetrics.F1Score: + label = localization.Interpret.Statistics.f1Score; + break; + case QuestionAnsweringMetrics.MeteorScore: + label = localization.Interpret.Statistics.meteorScore; + break; + case QuestionAnsweringMetrics.BleuScore: + label = localization.Interpret.Statistics.bleuScore; + break; + case QuestionAnsweringMetrics.BertScore: + label = localization.Interpret.Statistics.bertScore; + break; + case QuestionAnsweringMetrics.RougeScore: + label = localization.Interpret.Statistics.rougeScore; + break; + default: + break; + } + stats.push({ + key, + label, + stat: value + }); + } + return stats; + }); +}; diff --git a/libs/core-ui/src/lib/util/StatisticsUtils.ts b/libs/core-ui/src/lib/util/StatisticsUtils.ts index cde9e2ca93..42cbc72cd7 100644 --- a/libs/core-ui/src/lib/util/StatisticsUtils.ts +++ b/libs/core-ui/src/lib/util/StatisticsUtils.ts @@ -10,6 +10,7 @@ import { } from "../Interfaces/IStatistic"; import { IsBinary } from "../util/ExplanationUtils"; +import { generateGenerativeTextStats } from "./GenerativeTextStatisticsUtils"; import { JointDataset } from "./JointDataset"; import { ClassificationEnum } from "./JointDatasetUtils"; import { generateMulticlassStats } from "./MulticlassStatisticsUtils"; @@ -156,7 +157,8 @@ export const generateMetrics: ( modelType: ModelTypes, objectDetectionCache?: Map, objectDetectionInputs?: [string, string, number], - questionAnsweringCache?: QuestionAnsweringCacheType + questionAnsweringCache?: QuestionAnsweringCacheType, + generativeTextCache?: Map> ): ILabeledStatistic[][] => { if ( modelType === ModelTypes.ImageMultilabel || @@ -192,6 +194,9 @@ export const generateMetrics: ( objectDetectionInputs ); } + if (modelType === ModelTypes.GenerativeText && generativeTextCache) { + return generateGenerativeTextStats(selectionIndexes, generativeTextCache); + } const outcomes = jointDataset.unwrap(JointDataset.ClassificationError); if (IsBinary(modelType)) { return selectionIndexes.map((selectionArray) => { diff --git a/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/MetricSelector/MetricSelector.tsx b/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/MetricSelector/MetricSelector.tsx index f3514788e3..e244c12357 100644 --- a/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/MetricSelector/MetricSelector.tsx +++ b/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/MetricSelector/MetricSelector.tsx @@ -58,9 +58,12 @@ export class MetricSelector extends React.Component { options.push(this.addDropdownOption(Metrics.AccuracyScore)); } else if ( IsMultilabel(modelType) || - modelType === ModelTypes.ObjectDetection + modelType === ModelTypes.ObjectDetection || + modelType === ModelTypes.QuestionAnswering ) { options.push(this.addDropdownOption(Metrics.ErrorRate)); + } else if (modelType === ModelTypes.GenerativeText) { + options.push(this.addDropdownOption(Metrics.MeanSquaredError)); } return ( void; - requestObjectDetectionMetrics?: ( - selectionIndexes: number[][], - aggregateMethod: string, - className: string, - iouThreshold: number, - objectDetectionCache: Map - ) => Promise; - requestQuestionAnsweringMetrics?: ( - selectionIndexes: number[][], - questionAnsweringCache: Map< - string, - [number, number, number, number, number, number] - > - ) => Promise; } interface IModelOverviewState { @@ -88,6 +75,7 @@ interface IModelOverviewState { featureBasedCohortLabeledStatistics: ILabeledStatistic[][]; featureBasedCohorts: ErrorCohort[]; iouThreshold: number; + generativeTextAbortController: AbortController | undefined; objectDetectionAbortController: AbortController | undefined; questionAnsweringAbortController: AbortController | undefined; } @@ -100,6 +88,7 @@ export class ModelOverview extends React.Component< IModelOverviewState > { public static contextType = ModelAssessmentContext; + public generativeTextCache: Map> = new Map(); public questionAnsweringCache: Map< string, [number, number, number, number, number, number] @@ -125,6 +114,7 @@ export class ModelOverview extends React.Component< featureBasedCohortLabeledStatistics: [], featureBasedCohorts: [], featureConfigurationIsVisible: false, + generativeTextAbortController: undefined, iouThreshold: 70, metricConfigurationIsVisible: false, objectDetectionAbortController: undefined, @@ -184,6 +174,14 @@ export class ModelOverview extends React.Component< QuestionAnsweringMetrics.F1Score, QuestionAnsweringMetrics.BertScore ]; + } else if ( + this.context.dataset.task_type === DatasetTaskType.GenerativeText + ) { + defaultSelectedMetrics = [ + GenerativeTextMetrics.Fluency, + GenerativeTextMetrics.Coherence, + GenerativeTextMetrics.Relevance + ]; } else { // task_type === "regression" defaultSelectedMetrics = [ @@ -633,6 +631,10 @@ export class ModelOverview extends React.Component< this.context.modelMetadata.modelType === ModelTypes.QuestionAnswering ) { this.updateQuestionAnsweringMetrics(selectionIndexes, true); + } else if ( + this.context.modelMetadata.modelType === ModelTypes.GenerativeText + ) { + this.updateGenerativeTextMetrics(selectionIndexes, true); } }; @@ -838,6 +840,108 @@ export class ModelOverview extends React.Component< } } + private updateGenerativeTextMetrics( + selectionIndexes: number[][], + isDatasetCohort: boolean + ): void { + if (this.state.generativeTextAbortController !== undefined) { + this.state.generativeTextAbortController.abort(); + } + const newAbortController = new AbortController(); + this.setState({ generativeTextAbortController: newAbortController }); + if ( + this.context.requestGenerativeTextMetrics && + selectionIndexes.length > 0 + ) { + this.context + .requestGenerativeTextMetrics( + selectionIndexes, + this.generativeTextCache, + newAbortController.signal + ) + .then((result) => { + // Assumption: the lengths of `result` and `selectionIndexes` are the same. + const updatedMetricStats: ILabeledStatistic[][] = []; + + for (const [cohortIndex, metrics] of result.entries()) { + const count = selectionIndexes[cohortIndex].length; + const metricsMap = new Map(Object.entries(metrics)); + + if ( + !this.generativeTextCache.has( + selectionIndexes[cohortIndex].toString() + ) + ) { + this.generativeTextCache.set( + selectionIndexes[cohortIndex].toString(), + metricsMap + ); + } + + const updatedCohortMetricStats = [ + { + key: TotalCohortSamples, + label: localization.Interpret.Statistics.samples, + stat: count + } + ]; + + for (const [key, value] of metricsMap.entries()) { + let label = ""; + switch (key) { + case GenerativeTextMetrics.Coherence: + label = localization.Interpret.Statistics.coherence; + break; + case GenerativeTextMetrics.Fluency: + label = localization.Interpret.Statistics.fluency; + break; + case GenerativeTextMetrics.Equivalence: + label = localization.Interpret.Statistics.equivalence; + break; + case GenerativeTextMetrics.Groundedness: + label = localization.Interpret.Statistics.groundedness; + break; + case GenerativeTextMetrics.Relevance: + label = localization.Interpret.Statistics.relevance; + break; + case QuestionAnsweringMetrics.ExactMatchRatio: + label = localization.Interpret.Statistics.exactMatchRatio; + break; + case QuestionAnsweringMetrics.F1Score: + label = localization.Interpret.Statistics.f1Score; + break; + case QuestionAnsweringMetrics.MeteorScore: + label = localization.Interpret.Statistics.meteorScore; + break; + case QuestionAnsweringMetrics.BleuScore: + label = localization.Interpret.Statistics.bleuScore; + break; + case QuestionAnsweringMetrics.BertScore: + label = localization.Interpret.Statistics.bertScore; + break; + case QuestionAnsweringMetrics.RougeScore: + label = localization.Interpret.Statistics.rougeScore; + break; + default: + break; + } + updatedCohortMetricStats.push({ + key, + label, + stat: value + }); + } + + updatedMetricStats.push(updatedCohortMetricStats); + } + + isDatasetCohort + ? this.updateDatasetCohortState(updatedMetricStats) + : this.updateFeatureCohortState(updatedMetricStats); + }); + } + } + private updateDatasetCohortState( cohortMetricStats: ILabeledStatistic[][] ): void { @@ -884,6 +988,10 @@ export class ModelOverview extends React.Component< this.context.modelMetadata.modelType === ModelTypes.QuestionAnswering ) { this.updateQuestionAnsweringMetrics(selectionIndexes, false); + } else if ( + this.context.modelMetadata.modelType === ModelTypes.GenerativeText + ) { + this.updateGenerativeTextMetrics(selectionIndexes, false); } }; @@ -998,6 +1106,8 @@ export class ModelOverview extends React.Component< abortController = this.state.objectDetectionAbortController; } else if (taskType === DatasetTaskType.QuestionAnswering) { abortController = this.state.questionAnsweringAbortController; + } else if (taskType === DatasetTaskType.GenerativeText) { + abortController = this.state.generativeTextAbortController; } if (abortController !== undefined) { abortController.abort(); diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts index ee2af8e141..8ee368cc5e 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts @@ -6,6 +6,7 @@ import { BinaryClassificationMetrics, DatasetTaskType, ErrorCohort, + GenerativeTextMetrics, HighchartsNull, ILabeledStatistic, ModelTypes, @@ -154,7 +155,8 @@ export function generateCohortsStatsTable( const colorConfig = useTexturedBackgroundForNaN && modelType !== ModelTypes.ObjectDetection && - modelType !== ModelTypes.QuestionAnswering + modelType !== ModelTypes.QuestionAnswering && + modelType !== ModelTypes.GenerativeText ? { color: { pattern: { @@ -458,6 +460,90 @@ export function getSelectableMetrics( text: localization.ModelAssessment.ModelOverview.metrics.bleuScore.name }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.bertScore + .description, + key: QuestionAnsweringMetrics.BertScore, + text: localization.ModelAssessment.ModelOverview.metrics.bertScore.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.rougeScore + .description, + key: QuestionAnsweringMetrics.RougeScore, + text: localization.ModelAssessment.ModelOverview.metrics.rougeScore.name + } + ); + } else if (taskType === DatasetTaskType.GenerativeText) { + selectableMetrics.push( + { + description: + localization.ModelAssessment.ModelOverview.metrics.coherence + .description, + key: GenerativeTextMetrics.Coherence, + text: localization.ModelAssessment.ModelOverview.metrics.coherence.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.fluency + .description, + key: GenerativeTextMetrics.Fluency, + text: localization.ModelAssessment.ModelOverview.metrics.fluency.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.equivalence + .description, + key: GenerativeTextMetrics.Equivalence, + text: localization.ModelAssessment.ModelOverview.metrics.equivalence + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.groundedness + .description, + key: GenerativeTextMetrics.Groundedness, + text: localization.ModelAssessment.ModelOverview.metrics.groundedness + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.relevance + .description, + key: GenerativeTextMetrics.Relevance, + text: localization.ModelAssessment.ModelOverview.metrics.relevance.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.exactMatchRatio + .description, + key: QuestionAnsweringMetrics.ExactMatchRatio, + text: localization.ModelAssessment.ModelOverview.metrics.exactMatchRatio + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.meteorScore + .description, + key: QuestionAnsweringMetrics.MeteorScore, + text: localization.ModelAssessment.ModelOverview.metrics.meteorScore + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.f1Score + .description, + key: QuestionAnsweringMetrics.F1Score, + text: localization.ModelAssessment.ModelOverview.metrics.f1Score.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.bleuScore + .description, + key: QuestionAnsweringMetrics.BleuScore, + text: localization.ModelAssessment.ModelOverview.metrics.bleuScore.name + }, { description: localization.ModelAssessment.ModelOverview.metrics.bertScore diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts index 417755fdc6..a3a836f6fd 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts @@ -57,14 +57,6 @@ export interface ITabsViewProps { request: any[], abortSignal: AbortSignal ) => Promise; - requestQuestionAnsweringMetrics?: ( - selectionIndexes: number[][], - questionAnsweringCache: Map< - string, - [number, number, number, number, number, number] - >, - abortSignal: AbortSignal - ) => Promise; requestDebugML?: ( request: any[], abortSignal: AbortSignal diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx index 0db19920d6..bda8052184 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx @@ -89,6 +89,7 @@ export class ModelAssessmentDashboard extends CohortBasedComponent< this.props.requestDatasetAnalysisBoxChart, requestExp: this.props.requestExp, requestForecast: this.props.requestForecast, + requestGenerativeTextMetrics: this.props.requestGenerativeTextMetrics, requestGlobalCausalEffects: this.props.requestGlobalCausalEffects, requestGlobalCausalPolicy: this.props.requestGlobalCausalPolicy, requestGlobalExplanations: this.props.requestGlobalExplanations, @@ -143,9 +144,6 @@ export class ModelAssessmentDashboard extends CohortBasedComponent< this.props.requestObjectDetectionMetrics } requestPredictions={this.props.requestPredictions} - requestQuestionAnsweringMetrics={ - this.props.requestQuestionAnsweringMetrics - } requestDebugML={this.props.requestDebugML} requestImportances={this.props.requestImportances} requestMatrix={this.props.requestMatrix} diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts index 61651ca940..77c94f4a66 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts @@ -115,6 +115,11 @@ export interface IModelAssessmentDashboardProps index: number | number[], abortSignal: AbortSignal ) => Promise; + requestGenerativeTextMetrics?: ( + selectionIndexes: number[][], + generativeTextCache: Map>, + abortSignal: AbortSignal + ) => Promise; requestObjectDetectionMetrics?: ( selectionIndexes: number[][], aggregateMethod: string, diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts index e57a8f51a6..96646f77db 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts @@ -53,5 +53,8 @@ export function getModelTypeFromProps( if (taskType === DatasetTaskType.QuestionAnswering) { return ModelTypes.QuestionAnswering; } + if (taskType === DatasetTaskType.GenerativeText) { + return ModelTypes.GenerativeText; + } return modelType; }