diff --git a/js/src/lib/components/InferenceWidget/shared/helpers.ts b/js/src/lib/components/InferenceWidget/shared/helpers.ts index 62e8824e5..e89206258 100644 --- a/js/src/lib/components/InferenceWidget/shared/helpers.ts +++ b/js/src/lib/components/InferenceWidget/shared/helpers.ts @@ -1,5 +1,6 @@ import type { ModelData } from "../../../interfaces/Types"; import { randomItem, parseJSON } from "../../../utils/ViewUtils"; +import type { WidgetExample } from "./WidgetExample"; import type { ModelLoadInfo, TableData } from "./types"; export function getSearchParams(keys: string[]): string[] { @@ -19,6 +20,14 @@ export function getDemoInputs(model: ModelData, keys: (number | string)[]): any[ }); } +export function getWidgetExample( + model: ModelData, + validateExample: (sample: WidgetExample) => sample is TWidgetExample +): TWidgetExample | undefined { + const sample = model.widgetData?.length ? randomItem(model.widgetData) : undefined; + return sample && validateExample(sample) ? sample : undefined; +} + // Update current url search params, keeping existing keys intact. export function updateUrl(obj: Record): void { if (!window) { diff --git a/js/src/lib/components/InferenceWidget/widgets/FillMaskWidget/FillMaskWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/FillMaskWidget/FillMaskWidget.svelte index 08c16186d..24d767881 100644 --- a/js/src/lib/components/InferenceWidget/widgets/FillMaskWidget/FillMaskWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/FillMaskWidget/FillMaskWidget.svelte @@ -8,7 +8,13 @@ import WidgetTextarea from "../../shared/WidgetTextarea/WidgetTextarea.svelte"; import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte"; import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte"; - import { addInferenceParameters, getDemoInputs, getResponse, getSearchParams, updateUrl } from "../../shared/helpers"; + import { + addInferenceParameters, + getResponse, + getSearchParams, + getWidgetExample, + updateUrl, + } from "../../shared/helpers"; import { isValidOutputLabels } from "../../shared/outputValidation"; import { isTextInput } from "../../shared/inputValidation"; @@ -38,9 +44,8 @@ setTextAreaValue(textParam); getOutput(); } else { - const [demoText] = getDemoInputs(model, ["text"]); - /// TODO(get rid of useless getDemoInputs) - setTextAreaValue(demoText ?? ""); + const sample = getWidgetExample(model, validateExample); + setTextAreaValue(sample?.text ?? ""); if (text && callApiOnMount) { getOutput({ isOnLoadCall: true }); }