Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Widgets] Refactor examples running #1023

Merged
merged 27 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
667710d
Rm `previewInputSample` by reusing `applyInputSample`
mishig25 Oct 13, 2023
85aac2e
Rm unneeded `getDemoInputs` function
mishig25 Oct 13, 2023
c56c64f
better naming
mishig25 Oct 16, 2023
aa2396e
format
mishig25 Oct 16, 2023
d2374e4
correct typing
mishig25 Oct 16, 2023
d5325f9
add missing await
mishig25 Oct 16, 2023
2730c9c
Update js/src/lib/components/InferenceWidget/shared/helpers.ts
mishig25 Oct 16, 2023
3ba2298
format
mishig25 Oct 16, 2023
7858186
use `opts` syntax
mishig25 Oct 16, 2023
fd40b25
run onMount widget example inside WidgetWrapper
mishig25 Oct 26, 2023
756223d
all widgets run onMount example inside WidgetWrapper
mishig25 Oct 27, 2023
580c830
Rm `previewInputSample` by reusing `applyInputSample`
mishig25 Oct 13, 2023
f89936c
Rm unneeded `getDemoInputs` function
mishig25 Oct 13, 2023
0654367
better naming
mishig25 Oct 16, 2023
94249c4
format
mishig25 Oct 16, 2023
353f172
correct typing
mishig25 Oct 16, 2023
23a2fe0
add missing await
mishig25 Oct 16, 2023
c36fcdc
Update js/src/lib/components/InferenceWidget/shared/helpers.ts
mishig25 Oct 16, 2023
06420aa
format
mishig25 Oct 16, 2023
44eb4f1
use `opts` syntax
mishig25 Oct 16, 2023
ada2b70
run onMount widget example inside WidgetWrapper
mishig25 Oct 26, 2023
37f2ccd
all widgets run onMount example inside WidgetWrapper
mishig25 Oct 27, 2023
8446c3e
Merge branch 'rebase' into refactor_examples_running
mishig25 Oct 27, 2023
7ebca52
fix logic
mishig25 Oct 30, 2023
c89086b
stronger typing
mishig25 Oct 30, 2023
4d03786
format
mishig25 Oct 30, 2023
9ea310a
more refactor
mishig25 Oct 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ export type WidgetExampleAssetAndZeroShotInput<TOutput = WidgetExampleOutput> =
WidgetExampleZeroShotTextInput<TOutput>;

export interface WidgetExampleStructuredDataInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
structuredData: TableData;
structured_data: TableData;
}

export interface WidgetExampleTableDataInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<script lang="ts">
import type { ExampleRunOpts } from "../types";
import type { WidgetExample } from "../WidgetExample";

type TWidgetExample = $$Generic<WidgetExample>;
Expand All @@ -10,8 +11,7 @@
export let classNames = "";
export let isLoading = false;
export let inputSamples: TWidgetExample[];
export let applyInputSample: (sample: TWidgetExample) => void;
export let previewInputSample: (sample: TWidgetExample) => void;
export let applyInputSample: (sample: TWidgetExample, opts?: ExampleRunOpts) => void;

let containerEl: HTMLElement;
let isOptionsVisible = false;
Expand All @@ -32,7 +32,7 @@

function _previewInputSample(idx: number) {
const sample = inputSamples[idx];
previewInputSample(sample);
applyInputSample(sample, { isPreview: true });
}

function toggleOptionsVisibility() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<script lang="ts">
import type { WidgetProps, ModelLoadInfo } from "../types";
import type { WidgetProps, ModelLoadInfo, ExampleRunOpts } from "../types";
import type { WidgetExample } from "../WidgetExample";
import type { QueryParam } from "../../shared/helpers";

type TWidgetExample = $$Generic<WidgetExample>;

Expand All @@ -13,10 +14,11 @@
import WidgetHeader from "../WidgetHeader/WidgetHeader.svelte";
import WidgetInfo from "../WidgetInfo/WidgetInfo.svelte";
import WidgetModelLoading from "../WidgetModelLoading/WidgetModelLoading.svelte";
import { getModelLoadInfo } from "../../shared/helpers";
import { getModelLoadInfo, getQueryParamVal, getWidgetExample } from "../../shared/helpers";
import { modelLoadStates } from "../../stores";

export let apiUrl: string;
export let callApiOnMount: WidgetProps["callApiOnMount"];
export let computeTime: string;
export let error: string;
export let isLoading = false;
Expand All @@ -28,9 +30,9 @@
};
export let noTitle = false;
export let outputJson: string;
export let applyInputSample: (sample: TWidgetExample) => void = () => {};
export let previewInputSample: (sample: TWidgetExample) => void = () => {};
export let applyInputSample: (sample: TWidgetExample, opts?: ExampleRunOpts) => void = () => {};
export let validateExample: (sample: WidgetExample) => sample is TWidgetExample;
export let exampleQueryParams: QueryParam[] = [];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be typed further?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like this for example (requires some ugly casts in onMount but I think I'm fine with it as it makes it almost impossible to pass wrong query params)

Suggested change
export let exampleQueryParams: QueryParam[] = [];
export let exampleQueryParams: (keyof TWidgetExample)[] = [];

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, stronger typing is possible c89086b


let isMaximized = false;
let modelLoadInfo: ModelLoadInfo | undefined = undefined;
Expand Down Expand Up @@ -64,6 +66,24 @@
(async () => {
modelLoadInfo = await getModelLoadInfo(apiUrl, model.id, includeCredentials);
$modelLoadStates[model.id] = modelLoadInfo;

const exampleFromQueryParams = {} as TWidgetExample;
for (const key of exampleQueryParams) {
const val = getQueryParamVal(key);
if(val){
exampleFromQueryParams[key] = val;
}
}
if (Object.keys(exampleFromQueryParams).length) {
// run widget example from query params
applyInputSample(exampleFromQueryParams);
} else {
// run random widget example
const example = getWidgetExample<TWidgetExample>(model, validateExample);
if (callApiOnMount && example) {
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
applyInputSample(example, { inferenceOpts: { isOnLoadCall: true } });
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure but isOnLoadCall should prioritize the example output when specified, no?
Probably for a subsequent PR

})();
});

Expand Down Expand Up @@ -107,7 +127,6 @@
{isLoading}
inputSamples={selectedInputSamples?.inputSamples ?? []}
{applyInputSample}
{previewInputSample}
/>
</div>
{/if}
Expand Down
37 changes: 24 additions & 13 deletions js/src/lib/components/InferenceWidget/shared/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
import type { ModelData } from "../../../interfaces/Types";
import { randomItem, parseJSON } from "../../../utils/ViewUtils";
import type { WidgetExample } from "./WidgetExample";
import type { ModelLoadInfo, TableData } from "./types";

export function getSearchParams(keys: string[]): string[] {
type KeysOfUnion<T> = T extends any ? keyof T : never;
export type QueryParam = KeysOfUnion<WidgetExample>;
type QueryParamVal = string | null | boolean | (string | number)[][];
export function getQueryParamVal(key: QueryParam): QueryParamVal {
const searchParams = new URL(window.location.href).searchParams;
return keys.map(key => {
const value = searchParams.get(key);
return value || "";
});
const value = searchParams.get(key);
if (["text", "context", "question", "query", "candidate_labels"].includes(key)) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will not work on existing URLs of ZeroShotClassification widget with URL query candidateLabels rather than candidate_labels. However, ZeroShotClassification widgets are not used widely. Therefore, I think this breaking change is fine

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think i agree for candidate_labels (it's used in the spec for widget examples) but for structuredData it's used that way in the spec no?

https://huggingface.co/docs/hub/models-widgets-examples#structured-data-classification

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdyt #1064?
I can either merge #1064 or revert the changes back to using strucutredData. I will let you decide

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see my comment there

return value;
} else if (["table", "structured_data"].includes(key)) {
Copy link
Collaborator Author

@mishig25 mishig25 Oct 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as the comment above but for structuredData vs structured_data

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const table = convertDataToTable((parseJSON(value) as TableData) ?? {});
return table;
} else if (key === "multi_class") {
return value === "true";
}
return value;
}

export function getDemoInputs(model: ModelData, keys: (number | string)[]): any[] {
const widgetData = Array.isArray(model.widgetData) ? model.widgetData : [];
const randomEntry = (randomItem(widgetData) ?? {}) as any;
return keys.map(key => {
const value = randomEntry[key] ? randomEntry[key] : null;
return value ? randomEntry[key] : null;
});
export function getWidgetExample<TWidgetExample extends WidgetExample>(
model: ModelData,
validateExample: (sample: WidgetExample) => sample is TWidgetExample
): TWidgetExample | undefined {
const validExamples = model.widgetData?.filter(
(sample): sample is TWidgetExample => sample && validateExample(sample)
);
return validExamples?.length ? randomItem(validExamples) : undefined;
}

// Update current url search params, keeping existing keys intact.
export function updateUrl(obj: Record<string, string | undefined>): void {
export function updateUrl(obj: Partial<Record<QueryParam, string | undefined>>): void {
if (!window) {
return;
}
Expand Down
11 changes: 11 additions & 0 deletions js/src/lib/components/InferenceWidget/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ export interface WidgetProps {
isLoggedIn?: boolean;
}

export interface InferenceRunFlags {
withModelLoading?: boolean;
isOnLoadCall?: boolean;
useCache?: boolean;
}

export interface ExampleRunOpts {
isPreview?: boolean;
inferenceOpts?: InferenceRunFlags;
}

export type LoadState = "Loadable" | "Loaded" | "TooBig" | "error";

export type ComputeType = "cpu" | "gpu";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
<script lang="ts">
import type { WidgetProps } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetExample, WidgetExampleAssetInput, WidgetExampleOutputLabels } from "../../shared/WidgetExample";

import { onMount } from "svelte";

import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte";
import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
import WidgetOutputChart from "../../shared/WidgetOutputChart/WidgetOutputChart.svelte";
import WidgetRecorder from "../../shared/WidgetRecorder/WidgetRecorder.svelte";
import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte";
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
import { callInferenceApi, getBlobFromUrl, getDemoInputs } from "../../shared/helpers";
import { callInferenceApi, getBlobFromUrl } from "../../shared/helpers";
import { isValidOutputLabels } from "../../shared/outputValidation";
import { isAssetInput } from "../../shared/inputValidation";

Expand Down Expand Up @@ -61,7 +59,7 @@
}
}

async function getOutput({ withModelLoading = false, isOnLoadCall = false } = {}) {
async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
if (!file && !selectedSampleUrl) {
error = "You must select or record an audio file";
output = [];
Expand Down Expand Up @@ -122,42 +120,32 @@
throw new TypeError("Invalid output: output must be of type Array<label: string, score:number>");
}

function applyInputSample(sample: WidgetExampleAssetInput<WidgetExampleOutputLabels>) {
file = null;
function applyInputSample(sample: WidgetExampleAssetInput<WidgetExampleOutputLabels>, opts: ExampleRunOpts = {}) {
filename = sample.example_title!;
fileUrl = sample.src;
selectedSampleUrl = sample.src;
getOutput();
}

function previewInputSample(sample: WidgetExampleAssetInput<WidgetExampleOutputLabels>) {
filename = sample.example_title!;
fileUrl = sample.src;
if (isValidOutputLabels(sample.output)) {
output = sample.output;
outputJson = "";
} else {
output = [];
outputJson = "";
if (opts.isPreview) {
if (isValidOutputLabels(sample.output)) {
output = sample.output;
outputJson = "";
} else {
output = [];
outputJson = "";
}
return;
}
file = null;
selectedSampleUrl = sample.src;
getOutput(opts.inferenceOpts);
}

function validateExample(sample: WidgetExample): sample is WidgetExampleAssetInput<WidgetExampleOutputLabels> {
return isAssetInput(sample) && (!sample.output || isValidOutputLabels(sample.output));
}

onMount(() => {
const [exampleTitle, src] = getDemoInputs(model, ["example_title", "src"]);
if (callApiOnMount && src) {
filename = exampleTitle ?? "";
fileUrl = src;
selectedSampleUrl = src;
getOutput({ isOnLoadCall: true });
}
});
</script>

<WidgetWrapper
{callApiOnMount}
{apiUrl}
{includeCredentials}
{applyInputSample}
Expand All @@ -168,7 +156,6 @@
{modelLoading}
{noTitle}
{outputJson}
{previewInputSample}
{validateExample}
>
<svelte:fragment slot="top">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
<script lang="ts">
import type { WidgetProps } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetExampleAssetInput } from "../../shared/WidgetExample";

import { onMount } from "svelte";

import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte";
import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
import WidgetRecorder from "../../shared/WidgetRecorder/WidgetRecorder.svelte";
import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte";
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
import { callInferenceApi, getBlobFromUrl, getDemoInputs } from "../../shared/helpers";
import { callInferenceApi, getBlobFromUrl } from "../../shared/helpers";
import { isAssetInput } from "../../shared/inputValidation";

export let apiToken: WidgetProps["apiToken"];
Expand Down Expand Up @@ -64,7 +62,7 @@
}
}

async function getOutput({ withModelLoading = false, isOnLoadCall = false } = {}) {
async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
if (!file && !selectedSampleUrl) {
error = "You must select or record an audio file";
return;
Expand Down Expand Up @@ -126,33 +124,22 @@
throw new TypeError("Invalid output: output must be of type Array<blob:string, label:string, content-type:string>");
}

function applyInputSample(sample: WidgetExampleAssetInput) {
file = null;
function applyInputSample(sample: WidgetExampleAssetInput, opts: ExampleRunOpts = {}) {
filename = sample.example_title ?? "";
fileUrl = sample.src;
if (opts.isPreview) {
output = [];
outputJson = "";
return;
}
file = null;
selectedSampleUrl = sample.src;
getOutput();
}

function previewInputSample(sample: WidgetExampleAssetInput) {
filename = sample.example_title ?? "";
fileUrl = sample.src;
output = [];
outputJson = "";
getOutput(opts.inferenceOpts);
}

onMount(() => {
const [exampleTitle, src] = getDemoInputs(model, ["example_title", "src"]);
if (callApiOnMount && src) {
filename = exampleTitle ?? "";
fileUrl = src;
selectedSampleUrl = src;
getOutput({ isOnLoadCall: true });
}
});
</script>

<WidgetWrapper
{callApiOnMount}
{apiUrl}
{includeCredentials}
{applyInputSample}
Expand All @@ -163,7 +150,6 @@
{modelLoading}
{noTitle}
{outputJson}
{previewInputSample}
validateExample={isAssetInput}
>
<svelte:fragment slot="top">
Expand Down
Loading
Loading