Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Widgets] Don't call api inference if output exists #1063

Merged
merged 2 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions js/src/lib/components/InferenceWidget/shared/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { ModelData } from "../../../interfaces/Types";
import type { WidgetExampleOutput } from "./WidgetExample";

export interface WidgetProps {
apiToken?: string;
Expand All @@ -11,15 +12,16 @@ export interface WidgetProps {
isLoggedIn?: boolean;
}

export interface InferenceRunFlags {
export interface InferenceRunOpts<TOutput = WidgetExampleOutput> {
withModelLoading?: boolean;
isOnLoadCall?: boolean;
useCache?: boolean;
exampleOutput?: TOutput;
}

export interface ExampleRunOpts {
isPreview?: boolean;
inferenceOpts?: InferenceRunFlags;
inferenceOpts?: InferenceRunOpts;
}

export type LoadState = "Loadable" | "Loaded" | "TooBig" | "error";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExample, WidgetExampleAssetInput, WidgetExampleOutputLabels } from "../../shared/WidgetExample";

import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte";
Expand Down Expand Up @@ -59,7 +59,17 @@
}
}

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
async function getOutput({
withModelLoading = false,
isOnLoadCall = false,
exampleOutput = undefined,
}: InferenceRunOpts<WidgetExampleOutputLabels> = {}) {
if (exampleOutput) {
output = exampleOutput;
outputJson = "";
return;
}

if (!file && !selectedSampleUrl) {
error = "You must select or record an audio file";
output = [];
Expand Down Expand Up @@ -136,7 +146,8 @@
}
file = null;
selectedSampleUrl = sample.src;
getOutput(opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput({ ...opts.inferenceOpts, exampleOutput });
}

function validateExample(sample: WidgetExample): sample is WidgetExampleAssetInput<WidgetExampleOutputLabels> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExampleAssetInput } from "../../shared/WidgetExample";

import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte";
Expand Down Expand Up @@ -62,7 +62,11 @@
}
}

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
async function getOutput({
withModelLoading = false,
isOnLoadCall = false,
exampleOutput = undefined,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has no effect in this widget

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, because widgets such as AudioToAudio or ImageSegmentation are not expected to have example.output in their model cards

for example here

function applyInputSample(sample: WidgetExampleAssetInput, opts: ExampleRunOpts = {}) {
, if they were expected to have example.output, the typing would have been sample: WidgetExampleAssetInput<WIDGET OUTPUT TYPE>

docs here for the list of supported output types: https://huggingface.co/docs/hub/models-widgets#example-outputs from #978

@julien-c could you confirm this statement

yes, because widgets such as AudioToAudio or ImageSegmentation are not expected to have example.output in their model cards

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not sure i understand, i'm a bit lost...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from the docs (emphasis mine):

each widget example can also optionally describe the corresponding model output, directly in the output property

AudioToAudio or ImageSegmentation should support a file output, i.e. a output.url, same as text-to-image:

image

Copy link
Collaborator Author

@mishig25 mishig25 Oct 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is what I meant:

At the current state of widgets, uses can supply output in their modelcard for any widget. However, only subset of widgets will show/use that example output

AudioToAudio or ImageSegmentation should support a file output

But in the current implementation, AudioToAudio or ImageSegmentation do neither validation nor usage of output

Should I open a PR (extend the current PR) to make all the remaining widgets validate and use output ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed internally, lets merge this PR without the proposed changes above (remaining widgets validate and use output). And iterate in subseq PRs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I open a PR (extend the current PR) to make all the remaining widgets validate and use output ?

Yes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that was the goal to support output for all widgets

}: InferenceRunOpts = {}) {
if (!file && !selectedSampleUrl) {
error = "You must select or record an audio file";
return;
Expand Down Expand Up @@ -134,7 +138,8 @@
}
file = null;
selectedSampleUrl = sample.src;
getOutput(opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput({ ...opts.inferenceOpts, exampleOutput });
}
</script>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExample, WidgetExampleAssetInput, WidgetExampleOutputText } from "../../shared/WidgetExample";

import WidgetAudioTrack from "../../shared/WidgetAudioTrack/WidgetAudioTrack.svelte";
Expand Down Expand Up @@ -61,7 +61,17 @@
}
}

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
async function getOutput({
withModelLoading = false,
isOnLoadCall = false,
exampleOutput = undefined,
}: InferenceRunOpts<WidgetExampleOutputText> = {}) {
if (exampleOutput) {
output = exampleOutput.text;
outputJson = "";
return;
}

if (!file && !selectedSampleUrl) {
error = "You must select or record an audio file";
output = "";
Expand Down Expand Up @@ -137,7 +147,8 @@
}
file = null;
selectedSampleUrl = sample.src;
getOutput(opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput({ ...opts.inferenceOpts, exampleOutput });
}

function updateModelLoading(isLoading: boolean, estimatedTime: number = 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExampleTextInput } from "../../shared/WidgetExample";

import WidgetOutputConvo from "../../shared/WidgetOutputConvo/WidgetOutputConvo.svelte";
Expand Down Expand Up @@ -48,7 +48,11 @@
let outputJson: string;
let text = "";

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
async function getOutput({
withModelLoading = false,
isOnLoadCall = false,
exampleOutput = undefined,
}: InferenceRunOpts = {}) {
const trimmedText = text.trim();

if (!trimmedText) {
Expand Down Expand Up @@ -143,7 +147,8 @@
if (opts.isPreview) {
return;
}
getOutput(opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput({ ...opts.inferenceOpts, exampleOutput });
}
</script>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExampleTextInput } from "../../shared/WidgetExample";

import WidgetQuickInput from "../../shared/WidgetQuickInput/WidgetQuickInput.svelte";
Expand Down Expand Up @@ -28,12 +28,16 @@
let outputJson: string;
let text = "";

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
async function getOutput({
withModelLoading = false,
isOnLoadCall = false,
exampleOutput = undefined,
}: InferenceRunOpts = {}) {
const trimmedText = text.trim();

if (!trimmedText) {
error = "You need to input some text";
output = undefined;
exampleOutput = undefined;
outputJson = "";
return;
}
Expand Down Expand Up @@ -63,7 +67,7 @@
computeTime = "";
error = "";
modelLoading = { isLoading: false, estimatedTime: 0 };
output = undefined;
exampleOutput = undefined;
outputJson = "";

if (res.status === "success") {
Expand Down Expand Up @@ -111,7 +115,8 @@
if (opts.isPreview) {
return;
}
getOutput(opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput({ ...opts.inferenceOpts, exampleOutput });
}
</script>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExampleTextInput, WidgetExampleOutputLabels, WidgetExample } from "../../shared/WidgetExample";

import WidgetOutputChart from "../../shared/WidgetOutputChart/WidgetOutputChart.svelte";
Expand Down Expand Up @@ -30,7 +30,17 @@
let text = "";
let setTextAreaValue: (text: string) => void;

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
async function getOutput({
withModelLoading = false,
isOnLoadCall = false,
exampleOutput = undefined,
}: InferenceRunOpts<WidgetExampleOutputLabels> = {}) {
if (exampleOutput) {
output = exampleOutput;
outputJson = "";
return;
}

const trimmedText = text.trim();

if (!trimmedText) {
Expand Down Expand Up @@ -112,7 +122,8 @@
}
return;
}
getOutput(opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput({ ...opts.inferenceOpts, exampleOutput });
}

function validateExample(sample: WidgetExample): sample is WidgetExampleTextInput<WidgetExampleOutputLabels> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, InferenceRunFlags, ExampleRunOpts } from "../../shared/types";
import type { WidgetProps, InferenceRunOpts, ExampleRunOpts } from "../../shared/types";
import type { WidgetExample, WidgetExampleAssetInput, WidgetExampleOutputLabels } from "../../shared/WidgetExample";

import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
Expand Down Expand Up @@ -36,7 +36,7 @@

async function getOutput(
file: File | Blob,
{ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}
{ withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {}
) {
if (!file) {
return;
Expand Down Expand Up @@ -108,7 +108,8 @@
return;
}
const blob = await getBlobFromUrl(imgSrc);
getOutput(blob, opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput(blob, { ...opts.inferenceOpts, exampleOutput });
}

function validateExample(sample: WidgetExample): sample is WidgetExampleAssetInput<WidgetExampleOutputLabels> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ImageSegment, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ImageSegment, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExampleAssetInput } from "../../shared/WidgetExample";

import { onMount } from "svelte";
Expand Down Expand Up @@ -52,7 +52,7 @@

async function getOutput(
file: File | Blob,
{ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}
{ withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {}
) {
if (!file) {
return;
Expand Down Expand Up @@ -217,7 +217,8 @@
return;
}
const blob = await getBlobFromUrl(imgSrc);
getOutput(blob, opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput(blob, { ...opts.inferenceOpts, exampleOutput });
}

onMount(() => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExampleAssetAndPromptInput } from "../../shared/WidgetExample";

import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
Expand Down Expand Up @@ -70,10 +70,15 @@
const res = await fetch(imgSrc);
const blob = await res.blob();
await updateImageBase64(blob);
getOutput(opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput({ ...opts.inferenceOpts, exampleOutput });
}

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
async function getOutput({
withModelLoading = false,
isOnLoadCall = false,
exampleOutput = undefined,
}: InferenceRunOpts = {}) {
const trimmedPrompt = prompt.trim();

if (!imageBase64) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExampleAssetInput } from "../../shared/WidgetExample";

import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
Expand Down Expand Up @@ -35,7 +35,7 @@

async function getOutput(
file: File | Blob,
{ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}
{ withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {}
) {
if (!file) {
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, DetectedObject, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, DetectedObject, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type { WidgetExampleAssetInput } from "../../shared/WidgetExample";

import { mod } from "../../../../utils/ViewUtils";
Expand Down Expand Up @@ -40,7 +40,7 @@

async function getOutput(
file: File | Blob,
{ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}
{ withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {}
) {
if (!file) {
return;
Expand Down Expand Up @@ -136,7 +136,8 @@
return;
}
const blob = await getBlobFromUrl(imgSrc);
getOutput(blob, opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput(blob, { ...opts.inferenceOpts, exampleOutput });
}
</script>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunFlags } from "../../shared/types";
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types";
import type {
WidgetExample,
WidgetExampleOutputAnswerScore,
Expand Down Expand Up @@ -34,7 +34,11 @@
let question = "";
let setTextAreaValue: (text: string) => void;

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunFlags = {}) {
async function getOutput({
withModelLoading = false,
isOnLoadCall = false,
exampleOutput = undefined,
}: InferenceRunOpts = {}) {
const trimmedQuestion = question.trim();
const trimmedContext = context.trim();

Expand Down Expand Up @@ -113,7 +117,8 @@
if (opts.isPreview) {
return;
}
getOutput(opts.inferenceOpts);
const exampleOutput = sample.output;
getOutput({ ...opts.inferenceOpts, exampleOutput });
}

function validateExample(
Expand Down
Loading
Loading