-
Notifications
You must be signed in to change notification settings - Fork 264
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
403 additions
and
1 deletion.
There are no files selected for viewing
218 changes: 218 additions & 0 deletions
218
js/src/lib/components/InferenceWidget/shared/WidgetWrapperV2/WidgetWrapperV2.svelte
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
<script lang="ts"> | ||
import type { WidgetProps, ModelLoadInfo, ExampleRunOpts, WidgetInput, WidgetOutput } from "../types"; | ||
import type { WidgetExample, WidgetExampleAttribute } from "../WidgetExample"; | ||
type TWidgetExample = $$Generic<WidgetExample>; | ||
import { onMount } from "svelte"; | ||
import IconCross from "../../../Icons/IconCross.svelte"; | ||
import WidgetInputSamples from "../WidgetInputSamples/WidgetInputSamples.svelte"; | ||
import WidgetInputSamplesGroup from "../WidgetInputSamplesGroup/WidgetInputSamplesGroup.svelte"; | ||
import WidgetFooter from "../WidgetFooter/WidgetFooter.svelte"; | ||
import WidgetHeader from "../WidgetHeader/WidgetHeader.svelte"; | ||
import WidgetInfo from "../WidgetInfo/WidgetInfo.svelte"; | ||
import WidgetModelLoading from "../WidgetModelLoading/WidgetModelLoading.svelte"; | ||
import { callInferenceApi, getModelLoadInfo, getQueryParamVal, getWidgetExample } from "../helpers"; | ||
import { modelLoadStates } from "../../stores"; | ||
import { InferenceDisplayability } from "../../../../interfaces/InferenceDisplayability"; | ||
export let apiUrl: string; | ||
export let callApiOnMount: WidgetProps["callApiOnMount"]; | ||
export let apiToken: WidgetProps["apiToken"]; | ||
export let isLoading = false; | ||
export let model: WidgetProps["model"]; | ||
export let includeCredentials: WidgetProps["includeCredentials"]; | ||
export let noTitle = false; | ||
export let exampleQueryParams: WidgetExampleAttribute[] = []; | ||
export let widgetInput: WidgetInput; | ||
export let widgetOutput: WidgetOutput; | ||
let modelLoading = { | ||
isLoading: false, | ||
estimatedTime: 0, | ||
}; | ||
let computeTime: string; | ||
let outputJson: string; | ||
let error: string; | ||
let warning: string; | ||
let isDisabled = model.inference !== InferenceDisplayability.Yes && model.pipeline_tag !== "reinforcement-learning"; | ||
let isMaximized = false; | ||
let modelLoadInfo: ModelLoadInfo | undefined = undefined; | ||
let selectedInputGroup: string; | ||
let modelTooBig = false; | ||
interface ExamplesGroup { | ||
group: string; | ||
inputSamples: TWidgetExample[]; | ||
} | ||
const allInputSamples = (model.widgetData ?? []) | ||
.filter(validateExample) | ||
.sort((sample1, sample2) => (sample2.example_title ? 1 : 0) - (sample1.example_title ? 1 : 0)) | ||
.map((sample, idx) => ({ | ||
example_title: `Example ${++idx}`, | ||
group: "Group 1", | ||
...sample, | ||
})); | ||
let inputSamples = !isDisabled ? allInputSamples : allInputSamples.filter(sample => sample.output !== undefined); | ||
let inputGroups = getExamplesGroups(); | ||
$: selectedInputSamples = | ||
inputGroups.length === 1 ? inputGroups[0] : inputGroups.find(({ group }) => group === selectedInputGroup); | ||
function getExamplesGroups(): ExamplesGroup[] { | ||
const inputGroups: ExamplesGroup[] = []; | ||
for (const inputSample of inputSamples) { | ||
const groupExists = inputGroups.find(({ group }) => group === inputSample.group); | ||
if (!groupExists) { | ||
inputGroups.push({ group: inputSample.group as string, inputSamples: [] }); | ||
} | ||
inputGroups.find(({ group }) => group === inputSample.group)?.inputSamples.push(inputSample); | ||
} | ||
return inputGroups; | ||
} | ||
function validateExample(sample: WidgetExample): sample is TWidgetExample{ | ||
return widgetInput.validateExample(sample) && (!sample.output || widgetOutput.validateExample(sample.output)) | ||
} | ||
function applyInputSample(sample: TWidgetExample, opts?: ExampleRunOpts){ | ||
widgetInput.applyExample(sample, opts); | ||
widgetOutput.applyExample(sample, opts); | ||
} | ||
export async function getOutput({ | ||
Check failure on line 87 in js/src/lib/components/InferenceWidget/shared/WidgetWrapperV2/WidgetWrapperV2.svelte GitHub Actions / build
|
||
withModelLoading = false, | ||
isOnLoadCall = false, | ||
}) { | ||
let _requestBody; | ||
try { | ||
_requestBody = widgetInput.getInferenceInput(); | ||
} catch (err) { | ||
// errors such as: input can't be empty etc | ||
error = err; | ||
} | ||
const requestBody = _requestBody; | ||
isLoading = true; | ||
const res = await callInferenceApi( | ||
apiUrl, | ||
model.id, | ||
requestBody, | ||
apiToken, | ||
withModelLoading, | ||
includeCredentials, | ||
isOnLoadCall | ||
); | ||
isLoading = false; | ||
// Reset values | ||
computeTime = ""; | ||
error = ""; | ||
warning = ""; | ||
modelLoading = { isLoading: false, estimatedTime: 0 }; | ||
outputJson = ""; | ||
if (res.status === "success") { | ||
computeTime = res.computeTime; | ||
widgetOutput.showInferenceOutput(body); | ||
outputJson = res.outputJson; | ||
if (output.length === 0) { | ||
warning = "Inference output was empty"; | ||
} | ||
} else if (res.status === "loading-model") { | ||
modelLoading = { | ||
isLoading: true, | ||
estimatedTime: res.estimatedTime, | ||
}; | ||
getOutput({ withModelLoading: true }); | ||
} else if (res.status === "error") { | ||
error = res.error; | ||
} | ||
} | ||
onMount(() => { | ||
(async () => { | ||
modelLoadInfo = await getModelLoadInfo(apiUrl, model.id, includeCredentials); | ||
$modelLoadStates[model.id] = modelLoadInfo; | ||
modelTooBig = modelLoadInfo?.state === "TooBig"; | ||
if (modelTooBig) { | ||
// disable the widget | ||
isDisabled = true; | ||
inputSamples = allInputSamples.filter(sample => sample.output !== undefined); | ||
inputGroups = getExamplesGroups(); | ||
} | ||
const exampleFromQueryParams = {} as TWidgetExample; | ||
for (const key of exampleQueryParams) { | ||
const val = getQueryParamVal(key); | ||
if (val) { | ||
exampleFromQueryParams[key] = val; | ||
} | ||
} | ||
if (Object.keys(exampleFromQueryParams).length) { | ||
// run widget example from query params | ||
applyInputSample(exampleFromQueryParams); | ||
} else { | ||
// run random widget example | ||
const example = getWidgetExample<TWidgetExample>(model, validateExample); | ||
if (callApiOnMount && example) { | ||
applyInputSample(example, { inferenceOpts: { isOnLoadCall: true } }); | ||
} | ||
} | ||
})(); | ||
}); | ||
function onClickMaximizeBtn() { | ||
isMaximized = !isMaximized; | ||
} | ||
</script> | ||
|
||
{#if isDisabled && !inputSamples.length} | ||
<WidgetHeader pipeline={model.pipeline_tag} noTitle={true} /> | ||
<WidgetInfo {model} {computeTime} {error} {modelLoadInfo} {modelTooBig} /> | ||
{:else} | ||
<div | ||
class="flex w-full max-w-full flex-col | ||
{isMaximized ? 'fixed inset-0 z-20 bg-white p-12' : ''} | ||
{!modelLoadInfo ? 'hidden' : ''}" | ||
> | ||
{#if isMaximized} | ||
<button class="absolute right-12 top-6" on:click={onClickMaximizeBtn}> | ||
<IconCross classNames="text-xl text-gray-500 hover:text-black" /> | ||
</button> | ||
{/if} | ||
<WidgetHeader {noTitle} pipeline={model.pipeline_tag} {isDisabled}> | ||
{#if !!inputGroups.length} | ||
<div class="ml-auto flex gap-x-1"> | ||
<!-- Show samples selector when there are more than one sample --> | ||
{#if inputGroups.length > 1} | ||
<WidgetInputSamplesGroup | ||
bind:selectedInputGroup | ||
{isLoading} | ||
inputGroups={inputGroups.map(({ group }) => group)} | ||
/> | ||
{/if} | ||
<WidgetInputSamples | ||
classNames={!selectedInputSamples ? "opacity-50 pointer-events-none" : ""} | ||
{isLoading} | ||
inputSamples={selectedInputSamples?.inputSamples ?? []} | ||
{applyInputSample} | ||
/> | ||
</div> | ||
{/if} | ||
</WidgetHeader> | ||
<slot name="input" {isDisabled} /> | ||
<WidgetInfo {model} {computeTime} {error} {warning} {modelLoadInfo} {modelTooBig} /> | ||
{#if modelLoading.isLoading} | ||
<WidgetModelLoading estimatedTime={modelLoading.estimatedTime} /> | ||
{/if} | ||
<slot name="output" /> | ||
<WidgetFooter {onClickMaximizeBtn} {outputJson} {isDisabled} /> | ||
</div> | ||
{/if} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
...ts/InferenceWidget/widgets/AudioClassificationWidgetV2/AudioClassificationWidgetV2.svelte
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
<script lang="ts"> | ||
import type { WidgetProps } from "../../shared/types"; | ||
import type { WidgetOutput } from "../../wigdetsOutput/Types"; | ||
import type { WidgetInput } from "../../widgetsInput/Types"; | ||
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapperV2.svelte"; | ||
import WidgetInputAudio from "../../widgetsInput/WidgetInputAudio/WidgetInputAudio.svelte"; | ||
import WidgetOutputClassification from "../../wigdetsOutput/WidgetOutputClassification/WidgetOutputClassification.svelte"; | ||
export let apiToken: WidgetProps["apiToken"]; | ||
export let apiUrl: WidgetProps["apiUrl"]; | ||
export let callApiOnMount: WidgetProps["callApiOnMount"]; | ||
export let model: WidgetProps["model"]; | ||
export let noTitle: WidgetProps["noTitle"]; | ||
export let includeCredentials: WidgetProps["includeCredentials"]; | ||
export let widgetInput: WidgetInput; | ||
export let widgetOutput: WidgetOutput; | ||
export let isLoading = false; | ||
export let getOutput: ({withModelLoading = false,isOnLoadCall = false,}) => Promise<void>; | ||
</script> | ||
|
||
<WidgetWrapper | ||
{apiToken} | ||
{callApiOnMount} | ||
{apiUrl} | ||
{includeCredentials} | ||
{isLoading} | ||
{model} | ||
{noTitle} | ||
{widgetInput} | ||
{widgetOutput} | ||
bind:getOutput | ||
> | ||
<svelte:fragment slot="input" let:isDisabled> | ||
<WidgetInputAudio bind:widgetInput {isLoading} {isDisabled} on:click={() => getOutput()} /> | ||
</svelte:fragment> | ||
<svelte:fragment slot="output"> | ||
<WidgetOutputClassification bind:widgetOutput /> | ||
</svelte:fragment> | ||
</WidgetWrapper> |
98 changes: 98 additions & 0 deletions
98
js/src/lib/components/InferenceWidget/widgetsInput/WidgetInputAudio/WidgetInputAudio.svelte
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
<script lang="ts"> | ||
import type { ExampleRunOpts, WidgetInput } from "../../shared/types"; | ||
import type { WidgetExample, WidgetExampleAssetInput } from "../../shared/WidgetExample"; | ||
import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte"; | ||
import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte"; | ||
import WidgetRecorder from "../../shared/WidgetRecorder/WidgetRecorder.svelte"; | ||
import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte"; | ||
import { getBlobFromUrl } from "../../shared/helpers"; | ||
import { createEventDispatcher } from "svelte"; | ||
export let isDisabled = false; | ||
export let isLoading = false; | ||
let error: string = ""; | ||
let file: Blob | File | null = null; | ||
let filename: string = ""; | ||
let fileUrl: string; | ||
let isRecording = false; | ||
let selectedSampleUrl = ""; | ||
let warning: string = ""; | ||
function onRecordStart() { | ||
file = null; | ||
filename = ""; | ||
fileUrl = ""; | ||
isRecording = true; | ||
} | ||
function onRecordError(err: string) { | ||
error = err; | ||
} | ||
function onSelectFile(updatedFile: Blob | File) { | ||
isRecording = false; | ||
selectedSampleUrl = ""; | ||
if (updatedFile.size !== 0) { | ||
const date = new Date(); | ||
const time = date.toLocaleTimeString("en-US"); | ||
filename = "name" in updatedFile ? updatedFile.name : `Audio recorded from browser [${time}]`; | ||
file = updatedFile; | ||
fileUrl = URL.createObjectURL(file); | ||
} | ||
} | ||
const dispatch = createEventDispatcher<{ click: void }>(); | ||
export const widgetInput: WidgetInput = { | ||
validateExample<TOutput>(sample: WidgetExample<TOutput>): sample is WidgetExampleAssetInput<TOutput> { | ||
return "src" in sample; | ||
}, | ||
applyExample(sample: WidgetExampleAssetInput, opts: ExampleRunOpts) { | ||
filename = sample.example_title!; | ||
fileUrl = sample.src; | ||
if (opts.isPreview) { | ||
return; | ||
} | ||
file = null; | ||
selectedSampleUrl = sample.src; | ||
}, | ||
async getInferenceInput() { | ||
if (!file && !selectedSampleUrl) { | ||
error = "You must select or record an audio file"; | ||
return; | ||
} | ||
if (!file && selectedSampleUrl) { | ||
file = await getBlobFromUrl(selectedSampleUrl); | ||
} | ||
const requestBody = { file }; | ||
return requestBody; | ||
}, | ||
}; | ||
</script> | ||
|
||
<form> | ||
<div class="flex flex-wrap items-center {isDisabled ? 'pointer-events-none hidden opacity-50' : ''}"> | ||
<WidgetFileInput accept="audio/*" classNames="mt-1.5 mr-2" {onSelectFile} /> | ||
<span class="mr-2 mt-1.5">or</span> | ||
<WidgetRecorder classNames="mt-1.5" {onRecordStart} onRecordStop={onSelectFile} onError={onRecordError} /> | ||
</div> | ||
{#if fileUrl} | ||
<WidgetAudioTrack classNames="mt-3" label={filename} src={fileUrl} /> | ||
{/if} | ||
<WidgetSubmitBtn | ||
classNames="mt-2" | ||
isDisabled={isRecording || isDisabled} | ||
{isLoading} | ||
onClick={() => { | ||
dispatch("click"); | ||
}} | ||
/> | ||
{#if warning} | ||
<div class="alert alert-warning mt-2">{warning}</div> | ||
{/if} | ||
</form> |
Oops, something went wrong.