Skip to content

Commit

Permalink
Proposal: refactor widgets
Browse files Browse the repository at this point in the history
  • Loading branch information
mishig25 committed Nov 7, 2023
1 parent f79d817 commit ff2c1b2
Show file tree
Hide file tree
Showing 5 changed files with 403 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
<script lang="ts">
import type { WidgetProps, ModelLoadInfo, ExampleRunOpts, WidgetInput, WidgetOutput } from "../types";
import type { WidgetExample, WidgetExampleAttribute } from "../WidgetExample";
type TWidgetExample = $$Generic<WidgetExample>;
import { onMount } from "svelte";
import IconCross from "../../../Icons/IconCross.svelte";
import WidgetInputSamples from "../WidgetInputSamples/WidgetInputSamples.svelte";
import WidgetInputSamplesGroup from "../WidgetInputSamplesGroup/WidgetInputSamplesGroup.svelte";
import WidgetFooter from "../WidgetFooter/WidgetFooter.svelte";
import WidgetHeader from "../WidgetHeader/WidgetHeader.svelte";
import WidgetInfo from "../WidgetInfo/WidgetInfo.svelte";
import WidgetModelLoading from "../WidgetModelLoading/WidgetModelLoading.svelte";
import { callInferenceApi, getModelLoadInfo, getQueryParamVal, getWidgetExample } from "../helpers";
import { modelLoadStates } from "../../stores";
import { InferenceDisplayability } from "../../../../interfaces/InferenceDisplayability";
export let apiUrl: string;
export let callApiOnMount: WidgetProps["callApiOnMount"];
export let apiToken: WidgetProps["apiToken"];
export let isLoading = false;
export let model: WidgetProps["model"];
export let includeCredentials: WidgetProps["includeCredentials"];
export let noTitle = false;
export let exampleQueryParams: WidgetExampleAttribute[] = [];
export let widgetInput: WidgetInput;
export let widgetOutput: WidgetOutput;
let modelLoading = {
isLoading: false,
estimatedTime: 0,
};
let computeTime: string;
let outputJson: string;
let error: string;
let warning: string;
let isDisabled = model.inference !== InferenceDisplayability.Yes && model.pipeline_tag !== "reinforcement-learning";
let isMaximized = false;
let modelLoadInfo: ModelLoadInfo | undefined = undefined;
let selectedInputGroup: string;
let modelTooBig = false;
interface ExamplesGroup {
group: string;
inputSamples: TWidgetExample[];
}
const allInputSamples = (model.widgetData ?? [])
.filter(validateExample)
.sort((sample1, sample2) => (sample2.example_title ? 1 : 0) - (sample1.example_title ? 1 : 0))
.map((sample, idx) => ({
example_title: `Example ${++idx}`,
group: "Group 1",
...sample,
}));
let inputSamples = !isDisabled ? allInputSamples : allInputSamples.filter(sample => sample.output !== undefined);
let inputGroups = getExamplesGroups();
$: selectedInputSamples =
inputGroups.length === 1 ? inputGroups[0] : inputGroups.find(({ group }) => group === selectedInputGroup);
function getExamplesGroups(): ExamplesGroup[] {
const inputGroups: ExamplesGroup[] = [];
for (const inputSample of inputSamples) {
const groupExists = inputGroups.find(({ group }) => group === inputSample.group);
if (!groupExists) {
inputGroups.push({ group: inputSample.group as string, inputSamples: [] });
}
inputGroups.find(({ group }) => group === inputSample.group)?.inputSamples.push(inputSample);
}
return inputGroups;
}
function validateExample(sample: WidgetExample): sample is TWidgetExample{
return widgetInput.validateExample(sample) && (!sample.output || widgetOutput.validateExample(sample.output))
}
function applyInputSample(sample: TWidgetExample, opts?: ExampleRunOpts){
widgetInput.applyExample(sample, opts);
widgetOutput.applyExample(sample, opts);
}
export async function getOutput({

Check failure on line 87 in js/src/lib/components/InferenceWidget/shared/WidgetWrapperV2/WidgetWrapperV2.svelte

View workflow job for this annotation

GitHub Actions / build

Missing return type on function

Check failure on line 87 in js/src/lib/components/InferenceWidget/shared/WidgetWrapperV2/WidgetWrapperV2.svelte

View workflow job for this annotation

GitHub Actions / build

Object pattern argument should be typed
withModelLoading = false,
isOnLoadCall = false,
}) {
let _requestBody;
try {
_requestBody = widgetInput.getInferenceInput();
} catch (err) {
// errors such as: input can't be empty etc
error = err;
}
const requestBody = _requestBody;
isLoading = true;
const res = await callInferenceApi(
apiUrl,
model.id,
requestBody,
apiToken,
withModelLoading,
includeCredentials,
isOnLoadCall
);
isLoading = false;
// Reset values
computeTime = "";
error = "";
warning = "";
modelLoading = { isLoading: false, estimatedTime: 0 };
outputJson = "";
if (res.status === "success") {
computeTime = res.computeTime;
widgetOutput.showInferenceOutput(body);
outputJson = res.outputJson;
if (output.length === 0) {
warning = "Inference output was empty";
}
} else if (res.status === "loading-model") {
modelLoading = {
isLoading: true,
estimatedTime: res.estimatedTime,
};
getOutput({ withModelLoading: true });
} else if (res.status === "error") {
error = res.error;
}
}
onMount(() => {
(async () => {
modelLoadInfo = await getModelLoadInfo(apiUrl, model.id, includeCredentials);
$modelLoadStates[model.id] = modelLoadInfo;
modelTooBig = modelLoadInfo?.state === "TooBig";
if (modelTooBig) {
// disable the widget
isDisabled = true;
inputSamples = allInputSamples.filter(sample => sample.output !== undefined);
inputGroups = getExamplesGroups();
}
const exampleFromQueryParams = {} as TWidgetExample;
for (const key of exampleQueryParams) {
const val = getQueryParamVal(key);
if (val) {
exampleFromQueryParams[key] = val;
}
}
if (Object.keys(exampleFromQueryParams).length) {
// run widget example from query params
applyInputSample(exampleFromQueryParams);
} else {
// run random widget example
const example = getWidgetExample<TWidgetExample>(model, validateExample);
if (callApiOnMount && example) {
applyInputSample(example, { inferenceOpts: { isOnLoadCall: true } });
}
}
})();
});
function onClickMaximizeBtn() {
isMaximized = !isMaximized;
}
</script>

{#if isDisabled && !inputSamples.length}
<WidgetHeader pipeline={model.pipeline_tag} noTitle={true} />
<WidgetInfo {model} {computeTime} {error} {modelLoadInfo} {modelTooBig} />
{:else}
<div
class="flex w-full max-w-full flex-col
{isMaximized ? 'fixed inset-0 z-20 bg-white p-12' : ''}
{!modelLoadInfo ? 'hidden' : ''}"
>
{#if isMaximized}
<button class="absolute right-12 top-6" on:click={onClickMaximizeBtn}>
<IconCross classNames="text-xl text-gray-500 hover:text-black" />
</button>
{/if}
<WidgetHeader {noTitle} pipeline={model.pipeline_tag} {isDisabled}>
{#if !!inputGroups.length}
<div class="ml-auto flex gap-x-1">
<!-- Show samples selector when there are more than one sample -->
{#if inputGroups.length > 1}
<WidgetInputSamplesGroup
bind:selectedInputGroup
{isLoading}
inputGroups={inputGroups.map(({ group }) => group)}
/>
{/if}
<WidgetInputSamples
classNames={!selectedInputSamples ? "opacity-50 pointer-events-none" : ""}
{isLoading}
inputSamples={selectedInputSamples?.inputSamples ?? []}
{applyInputSample}
/>
</div>
{/if}
</WidgetHeader>
<slot name="input" {isDisabled} />
<WidgetInfo {model} {computeTime} {error} {warning} {modelLoadInfo} {modelTooBig} />
{#if modelLoading.isLoading}
<WidgetModelLoading estimatedTime={modelLoading.estimatedTime} />
{/if}
<slot name="output" />
<WidgetFooter {onClickMaximizeBtn} {outputJson} {isDisabled} />
</div>
{/if}
19 changes: 18 additions & 1 deletion js/src/lib/components/InferenceWidget/shared/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { ModelData } from "../../../interfaces/Types";
import type { WidgetExampleOutput } from "./WidgetExample";
import type { WidgetExample, WidgetExampleOutput } from "./WidgetExample";

export interface WidgetProps {
apiToken?: string;
Expand Down Expand Up @@ -66,3 +66,20 @@ export interface ImageSegment {
imgData?: ImageData;
bitmap?: ImageBitmap;
}

export interface WidgetInput {
// example functions
validateExample: <TOutput>(sample: WidgetExample<TOutput>) => sample is WidgetExample<TOutput>;
applyExample: (sample: any, opts: ExampleRunOpts) => void;
// ingerence api functions
getInferenceInput: () => Promise<any>;
}

export interface WidgetOutput {
// example functions
validateExample: (arg: unknown) => boolean;
applyExample: (sample: any, opts: ExampleRunOpts) => void;
// ingerence api functions
validateInferenceOutput: (arg: unknown) => boolean;
showInferenceOutput: (res: any) => Promise<any>;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
<script lang="ts">
import type { WidgetProps } from "../../shared/types";
import type { WidgetOutput } from "../../wigdetsOutput/Types";
import type { WidgetInput } from "../../widgetsInput/Types";
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapperV2.svelte";

Check failure on line 6 in js/src/lib/components/InferenceWidget/widgets/AudioClassificationWidgetV2/AudioClassificationWidgetV2.svelte

View workflow job for this annotation

GitHub Actions / build

Unable to resolve path to module '../../shared/WidgetWrapper/WidgetWrapperV2.svelte'
import WidgetInputAudio from "../../widgetsInput/WidgetInputAudio/WidgetInputAudio.svelte";
import WidgetOutputClassification from "../../wigdetsOutput/WidgetOutputClassification/WidgetOutputClassification.svelte";
export let apiToken: WidgetProps["apiToken"];
export let apiUrl: WidgetProps["apiUrl"];
export let callApiOnMount: WidgetProps["callApiOnMount"];
export let model: WidgetProps["model"];
export let noTitle: WidgetProps["noTitle"];
export let includeCredentials: WidgetProps["includeCredentials"];
export let widgetInput: WidgetInput;
export let widgetOutput: WidgetOutput;
export let isLoading = false;
export let getOutput: ({withModelLoading = false,isOnLoadCall = false,}) => Promise<void>;
</script>

<WidgetWrapper
{apiToken}
{callApiOnMount}
{apiUrl}
{includeCredentials}
{isLoading}
{model}
{noTitle}
{widgetInput}
{widgetOutput}
bind:getOutput
>
<svelte:fragment slot="input" let:isDisabled>
<WidgetInputAudio bind:widgetInput {isLoading} {isDisabled} on:click={() => getOutput()} />
</svelte:fragment>
<svelte:fragment slot="output">
<WidgetOutputClassification bind:widgetOutput />
</svelte:fragment>
</WidgetWrapper>
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
<script lang="ts">
import type { ExampleRunOpts, WidgetInput } from "../../shared/types";
import type { WidgetExample, WidgetExampleAssetInput } from "../../shared/WidgetExample";
import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte";
import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
import WidgetRecorder from "../../shared/WidgetRecorder/WidgetRecorder.svelte";
import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte";
import { getBlobFromUrl } from "../../shared/helpers";
import { createEventDispatcher } from "svelte";
export let isDisabled = false;
export let isLoading = false;
let error: string = "";
let file: Blob | File | null = null;
let filename: string = "";
let fileUrl: string;
let isRecording = false;
let selectedSampleUrl = "";
let warning: string = "";
function onRecordStart() {
file = null;
filename = "";
fileUrl = "";
isRecording = true;
}
function onRecordError(err: string) {
error = err;
}
function onSelectFile(updatedFile: Blob | File) {
isRecording = false;
selectedSampleUrl = "";
if (updatedFile.size !== 0) {
const date = new Date();
const time = date.toLocaleTimeString("en-US");
filename = "name" in updatedFile ? updatedFile.name : `Audio recorded from browser [${time}]`;
file = updatedFile;
fileUrl = URL.createObjectURL(file);
}
}
const dispatch = createEventDispatcher<{ click: void }>();
export const widgetInput: WidgetInput = {
validateExample<TOutput>(sample: WidgetExample<TOutput>): sample is WidgetExampleAssetInput<TOutput> {
return "src" in sample;
},
applyExample(sample: WidgetExampleAssetInput, opts: ExampleRunOpts) {
filename = sample.example_title!;
fileUrl = sample.src;
if (opts.isPreview) {
return;
}
file = null;
selectedSampleUrl = sample.src;
},
async getInferenceInput() {
if (!file && !selectedSampleUrl) {
error = "You must select or record an audio file";
return;
}
if (!file && selectedSampleUrl) {
file = await getBlobFromUrl(selectedSampleUrl);
}
const requestBody = { file };
return requestBody;
},
};
</script>

<form>
<div class="flex flex-wrap items-center {isDisabled ? 'pointer-events-none hidden opacity-50' : ''}">
<WidgetFileInput accept="audio/*" classNames="mt-1.5 mr-2" {onSelectFile} />
<span class="mr-2 mt-1.5">or</span>
<WidgetRecorder classNames="mt-1.5" {onRecordStart} onRecordStop={onSelectFile} onError={onRecordError} />
</div>
{#if fileUrl}
<WidgetAudioTrack classNames="mt-3" label={filename} src={fileUrl} />
{/if}
<WidgetSubmitBtn
classNames="mt-2"
isDisabled={isRecording || isDisabled}
{isLoading}
onClick={() => {
dispatch("click");
}}
/>
{#if warning}
<div class="alert alert-warning mt-2">{warning}</div>
{/if}
</form>
Loading

0 comments on commit ff2c1b2

Please sign in to comment.