Skip to content

Commit

Permalink
[Widgets] Refactor examples running (#1023)
Browse files Browse the repository at this point in the history
* Rm `previewInputSample` by reusing `applyInputSample`

* Rm unneeded `getDemoInputs` function

* better naming

* format

* correct typing

* add missing await

* Update js/src/lib/components/InferenceWidget/shared/helpers.ts

Co-authored-by: Simon Brandeis <[email protected]>

* format

* use `opts` syntax

* run onMount widget example inside WidgetWrapper

* all widgets run onMount example inside WidgetWrapper

* Rm `previewInputSample` by reusing `applyInputSample`

* Rm unneeded `getDemoInputs` function

* better naming

* format

* correct typing

* add missing await

* Update js/src/lib/components/InferenceWidget/shared/helpers.ts

Co-authored-by: Simon Brandeis <[email protected]>

* format

* use `opts` syntax

* run onMount widget example inside WidgetWrapper

* all widgets run onMount example inside WidgetWrapper

* fix logic

* stronger typing

* format

* more refactor

---------

Co-authored-by: Simon Brandeis <[email protected]>
  • Loading branch information
mishig25 and SBrandeis authored Oct 30, 2023
1 parent a7c8fec commit cd15b19
Show file tree
Hide file tree
Showing 29 changed files with 362 additions and 724 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ export type WidgetExampleAssetAndZeroShotInput<TOutput = WidgetExampleOutput> =
WidgetExampleZeroShotTextInput<TOutput>;

export interface WidgetExampleStructuredDataInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
structuredData: TableData;
structured_data: TableData;
}

export interface WidgetExampleTableDataInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<script lang="ts">
import type { ExampleRunOpts } from "../types";
import type { WidgetExample } from "../WidgetExample";
type TWidgetExample = $$Generic<WidgetExample>;
Expand All @@ -10,8 +11,7 @@
export let classNames = "";
export let isLoading = false;
export let inputSamples: TWidgetExample[];
export let applyInputSample: (sample: TWidgetExample) => void;
export let previewInputSample: (sample: TWidgetExample) => void;
export let applyInputSample: (sample: TWidgetExample, opts?: ExampleRunOpts) => void;
let containerEl: HTMLElement;
let isOptionsVisible = false;
Expand All @@ -32,7 +32,7 @@
function _previewInputSample(idx: number) {
const sample = inputSamples[idx];
previewInputSample(sample);
applyInputSample(sample, { isPreview: true });
}
function toggleOptionsVisibility() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<script lang="ts">
import type { WidgetProps, ModelLoadInfo } from "../types";
import type { WidgetProps, ModelLoadInfo, ExampleRunOpts } from "../types";
import type { WidgetExample } from "../WidgetExample";
import type { QueryParam } from "../../shared/helpers";
type TWidgetExample = $$Generic<WidgetExample>;
Expand All @@ -13,10 +14,11 @@
import WidgetHeader from "../WidgetHeader/WidgetHeader.svelte";
import WidgetInfo from "../WidgetInfo/WidgetInfo.svelte";
import WidgetModelLoading from "../WidgetModelLoading/WidgetModelLoading.svelte";
import { getModelLoadInfo } from "../../shared/helpers";
import { getModelLoadInfo, getQueryParamVal, getWidgetExample } from "../../shared/helpers";
import { modelLoadStates } from "../../stores";
export let apiUrl: string;
export let callApiOnMount: WidgetProps["callApiOnMount"];
export let computeTime: string;
export let error: string;
export let isLoading = false;
Expand All @@ -28,9 +30,9 @@
};
export let noTitle = false;
export let outputJson: string;
export let applyInputSample: (sample: TWidgetExample) => void = () => {};
export let previewInputSample: (sample: TWidgetExample) => void = () => {};
export let applyInputSample: (sample: TWidgetExample, opts?: ExampleRunOpts) => void = () => {};
export let validateExample: (sample: WidgetExample) => sample is TWidgetExample;
export let exampleQueryParams: QueryParam[] = [];
let isMaximized = false;
let modelLoadInfo: ModelLoadInfo | undefined = undefined;
Expand Down Expand Up @@ -64,6 +66,24 @@
(async () => {
modelLoadInfo = await getModelLoadInfo(apiUrl, model.id, includeCredentials);
$modelLoadStates[model.id] = modelLoadInfo;
const exampleFromQueryParams = {} as TWidgetExample;
for (const key of exampleQueryParams) {
const val = getQueryParamVal(key);
if (val) {
exampleFromQueryParams[key] = val;
}
}
if (Object.keys(exampleFromQueryParams).length) {
// run widget example from query params
applyInputSample(exampleFromQueryParams);
} else {
// run random widget example
const example = getWidgetExample<TWidgetExample>(model, validateExample);
if (callApiOnMount && example) {
applyInputSample(example, { inferenceOpts: { isOnLoadCall: true } });
}
}
})();
});
Expand Down Expand Up @@ -107,7 +127,6 @@
{isLoading}
inputSamples={selectedInputSamples?.inputSamples ?? []}
{applyInputSample}
{previewInputSample}
/>
</div>
{/if}
Expand Down
40 changes: 27 additions & 13 deletions js/src/lib/components/InferenceWidget/shared/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,40 @@
import type { ModelData } from "../../../interfaces/Types";
import { randomItem, parseJSON } from "../../../utils/ViewUtils";
import type { WidgetExample } from "./WidgetExample";
import type { ModelLoadInfo, TableData } from "./types";

export function getSearchParams(keys: string[]): string[] {
type KeysOfUnion<T> = T extends any ? keyof T : never;
export type QueryParam = KeysOfUnion<WidgetExample>;
type QueryParamVal = string | null | boolean | (string | number)[][];
const KEYS_TEXT: QueryParam[] = ["text", "context", "candidate_labels"];
const KEYS_TABLE: QueryParam[] = ["table", "structured_data"];

export function getQueryParamVal(key: QueryParam): QueryParamVal {
const searchParams = new URL(window.location.href).searchParams;
return keys.map(key => {
const value = searchParams.get(key);
return value || "";
});
const value = searchParams.get(key);
if (KEYS_TEXT.includes(key)) {
return value;
} else if (KEYS_TABLE.includes(key)) {
const table = convertDataToTable((parseJSON(value) as TableData) ?? {});
return table;
} else if (key === "multi_class") {
return value === "true";
}
return value;
}

export function getDemoInputs(model: ModelData, keys: (number | string)[]): any[] {
const widgetData = Array.isArray(model.widgetData) ? model.widgetData : [];
const randomEntry = (randomItem(widgetData) ?? {}) as any;
return keys.map(key => {
const value = randomEntry[key] ? randomEntry[key] : null;
return value ? randomEntry[key] : null;
});
export function getWidgetExample<TWidgetExample extends WidgetExample>(
model: ModelData,
validateExample: (sample: WidgetExample) => sample is TWidgetExample
): TWidgetExample | undefined {
const validExamples = model.widgetData?.filter(
(sample): sample is TWidgetExample => sample && validateExample(sample)
);
return validExamples?.length ? randomItem(validExamples) : undefined;
}

// Update current url search params, keeping existing keys intact.
export function updateUrl(obj: Record<string, string | undefined>): void {
export function updateUrl(obj: Partial<Record<QueryParam, string | undefined>>): void {
if (!window) {
return;
}
Expand Down
11 changes: 11 additions & 0 deletions js/src/lib/components/InferenceWidget/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ export interface WidgetProps {
isLoggedIn?: boolean;
}

export interface InferenceRunFlags {
withModelLoading?: boolean;
isOnLoadCall?: boolean;
useCache?: boolean;
}

export interface ExampleRunOpts {
isPreview?: boolean;
inferenceOpts?: InferenceRunFlags;
}

export type LoadState = "Loadable" | "Loaded" | "TooBig" | "error";

export type ComputeType = "cpu" | "gpu";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
<script lang="ts">
import type { WidgetProps } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetExample, WidgetExampleAssetInput, WidgetExampleOutputLabels } from "../../shared/WidgetExample";
import { onMount } from "svelte";
import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte";
import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
import WidgetOutputChart from "../../shared/WidgetOutputChart/WidgetOutputChart.svelte";
import WidgetRecorder from "../../shared/WidgetRecorder/WidgetRecorder.svelte";
import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte";
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
import { callInferenceApi, getBlobFromUrl, getDemoInputs } from "../../shared/helpers";
import { callInferenceApi, getBlobFromUrl } from "../../shared/helpers";
import { isValidOutputLabels } from "../../shared/outputValidation";
import { isAssetInput } from "../../shared/inputValidation";
Expand Down Expand Up @@ -61,7 +59,7 @@
}
}
async function getOutput({ withModelLoading = false, isOnLoadCall = false } = {}) {
async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
if (!file && !selectedSampleUrl) {
error = "You must select or record an audio file";
output = [];
Expand Down Expand Up @@ -122,42 +120,32 @@
throw new TypeError("Invalid output: output must be of type Array<label: string, score:number>");
}
function applyInputSample(sample: WidgetExampleAssetInput<WidgetExampleOutputLabels>) {
file = null;
function applyInputSample(sample: WidgetExampleAssetInput<WidgetExampleOutputLabels>, opts: ExampleRunOpts = {}) {
filename = sample.example_title!;
fileUrl = sample.src;
selectedSampleUrl = sample.src;
getOutput();
}
function previewInputSample(sample: WidgetExampleAssetInput<WidgetExampleOutputLabels>) {
filename = sample.example_title!;
fileUrl = sample.src;
if (isValidOutputLabels(sample.output)) {
output = sample.output;
outputJson = "";
} else {
output = [];
outputJson = "";
if (opts.isPreview) {
if (isValidOutputLabels(sample.output)) {
output = sample.output;
outputJson = "";
} else {
output = [];
outputJson = "";
}
return;
}
file = null;
selectedSampleUrl = sample.src;
getOutput(opts.inferenceOpts);
}
function validateExample(sample: WidgetExample): sample is WidgetExampleAssetInput<WidgetExampleOutputLabels> {
return isAssetInput(sample) && (!sample.output || isValidOutputLabels(sample.output));
}
onMount(() => {
const [exampleTitle, src] = getDemoInputs(model, ["example_title", "src"]);
if (callApiOnMount && src) {
filename = exampleTitle ?? "";
fileUrl = src;
selectedSampleUrl = src;
getOutput({ isOnLoadCall: true });
}
});
</script>

<WidgetWrapper
{callApiOnMount}
{apiUrl}
{includeCredentials}
{applyInputSample}
Expand All @@ -168,7 +156,6 @@
{modelLoading}
{noTitle}
{outputJson}
{previewInputSample}
{validateExample}
>
<svelte:fragment slot="top">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
<script lang="ts">
import type { WidgetProps } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetExampleAssetInput } from "../../shared/WidgetExample";
import { onMount } from "svelte";
import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte";
import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
import WidgetRecorder from "../../shared/WidgetRecorder/WidgetRecorder.svelte";
import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte";
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
import { callInferenceApi, getBlobFromUrl, getDemoInputs } from "../../shared/helpers";
import { callInferenceApi, getBlobFromUrl } from "../../shared/helpers";
import { isAssetInput } from "../../shared/inputValidation";
export let apiToken: WidgetProps["apiToken"];
Expand Down Expand Up @@ -64,7 +62,7 @@
}
}
async function getOutput({ withModelLoading = false, isOnLoadCall = false } = {}) {
async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
if (!file && !selectedSampleUrl) {
error = "You must select or record an audio file";
return;
Expand Down Expand Up @@ -126,33 +124,22 @@
throw new TypeError("Invalid output: output must be of type Array<blob:string, label:string, content-type:string>");
}
function applyInputSample(sample: WidgetExampleAssetInput) {
file = null;
function applyInputSample(sample: WidgetExampleAssetInput, opts: ExampleRunOpts = {}) {
filename = sample.example_title ?? "";
fileUrl = sample.src;
if (opts.isPreview) {
output = [];
outputJson = "";
return;
}
file = null;
selectedSampleUrl = sample.src;
getOutput();
}
function previewInputSample(sample: WidgetExampleAssetInput) {
filename = sample.example_title ?? "";
fileUrl = sample.src;
output = [];
outputJson = "";
getOutput(opts.inferenceOpts);
}
onMount(() => {
const [exampleTitle, src] = getDemoInputs(model, ["example_title", "src"]);
if (callApiOnMount && src) {
filename = exampleTitle ?? "";
fileUrl = src;
selectedSampleUrl = src;
getOutput({ isOnLoadCall: true });
}
});
</script>

<WidgetWrapper
{callApiOnMount}
{apiUrl}
{includeCredentials}
{applyInputSample}
Expand All @@ -163,7 +150,6 @@
{modelLoading}
{noTitle}
{outputJson}
{previewInputSample}
validateExample={isAssetInput}
>
<svelte:fragment slot="top">
Expand Down
Loading

0 comments on commit cd15b19

Please sign in to comment.