diff --git a/js/src/lib/components/InferenceWidget/shared/types.ts b/js/src/lib/components/InferenceWidget/shared/types.ts index 8931029ce..b08003e1e 100644 --- a/js/src/lib/components/InferenceWidget/shared/types.ts +++ b/js/src/lib/components/InferenceWidget/shared/types.ts @@ -1,4 +1,5 @@ import type { ModelData } from "../../../interfaces/Types"; +import type { WidgetExampleOutput } from "./WidgetExample"; export interface WidgetProps { apiToken?: string; @@ -11,10 +12,11 @@ export interface WidgetProps { isLoggedIn?: boolean; } -export interface InferenceRunOpts { +export interface InferenceRunOpts { withModelLoading?: boolean; isOnLoadCall?: boolean; useCache?: boolean; + exampleOutput?: TOutput; } export interface ExampleRunOpts { diff --git a/js/src/lib/components/InferenceWidget/widgets/AudioClassificationWidget/AudioClassificationWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/AudioClassificationWidget/AudioClassificationWidget.svelte index 16b7fd150..25b9c374e 100644 --- a/js/src/lib/components/InferenceWidget/widgets/AudioClassificationWidget/AudioClassificationWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/AudioClassificationWidget/AudioClassificationWidget.svelte @@ -59,7 +59,17 @@ } } - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { + if (exampleOutput) { + output = exampleOutput; + outputJson = ""; + return; + } + if (!file && !selectedSampleUrl) { error = "You must select or record an audio file"; output = []; @@ -136,7 +146,8 @@ } file = null; selectedSampleUrl = sample.src; - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } function validateExample(sample: WidgetExample): sample is WidgetExampleAssetInput { diff --git a/js/src/lib/components/InferenceWidget/widgets/AudioToAudioWidget/AudioToAudioWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/AudioToAudioWidget/AudioToAudioWidget.svelte index 522687ce2..f4ada5667 100644 --- a/js/src/lib/components/InferenceWidget/widgets/AudioToAudioWidget/AudioToAudioWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/AudioToAudioWidget/AudioToAudioWidget.svelte @@ -62,7 +62,11 @@ } } - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { if (!file && !selectedSampleUrl) { error = "You must select or record an audio file"; return; @@ -134,7 +138,8 @@ } file = null; selectedSampleUrl = sample.src; - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/AutomaticSpeechRecognitionWidget/AutomaticSpeechRecognitionWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/AutomaticSpeechRecognitionWidget/AutomaticSpeechRecognitionWidget.svelte index d29e9b6a1..18c32d58c 100644 --- a/js/src/lib/components/InferenceWidget/widgets/AutomaticSpeechRecognitionWidget/AutomaticSpeechRecognitionWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/AutomaticSpeechRecognitionWidget/AutomaticSpeechRecognitionWidget.svelte @@ -61,7 +61,17 @@ } } - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { + if (exampleOutput) { + output = exampleOutput.text; + outputJson = ""; + return; + } + if (!file && !selectedSampleUrl) { error = "You must select or record an audio file"; output = ""; @@ -137,7 +147,8 @@ } file = null; selectedSampleUrl = sample.src; - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } function updateModelLoading(isLoading: boolean, estimatedTime: number = 0) { diff --git a/js/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte index 957de2296..1434abe33 100644 --- a/js/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte @@ -48,7 +48,11 @@ let outputJson: string; let text = ""; - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedText = text.trim(); if (!trimmedText) { @@ -143,7 +147,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/FeatureExtractionWidget/FeatureExtractionWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/FeatureExtractionWidget/FeatureExtractionWidget.svelte index 44dc5f13b..05a7357f6 100644 --- a/js/src/lib/components/InferenceWidget/widgets/FeatureExtractionWidget/FeatureExtractionWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/FeatureExtractionWidget/FeatureExtractionWidget.svelte @@ -28,12 +28,16 @@ let outputJson: string; let text = ""; - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedText = text.trim(); if (!trimmedText) { error = "You need to input some text"; - output = undefined; + exampleOutput = undefined; outputJson = ""; return; } @@ -63,7 +67,7 @@ computeTime = ""; error = ""; modelLoading = { isLoading: false, estimatedTime: 0 }; - output = undefined; + exampleOutput = undefined; outputJson = ""; if (res.status === "success") { @@ -111,7 +115,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/FillMaskWidget/FillMaskWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/FillMaskWidget/FillMaskWidget.svelte index b02c69320..ae6542dc5 100644 --- a/js/src/lib/components/InferenceWidget/widgets/FillMaskWidget/FillMaskWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/FillMaskWidget/FillMaskWidget.svelte @@ -30,7 +30,17 @@ let text = ""; let setTextAreaValue: (text: string) => void; - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { + if (exampleOutput) { + output = exampleOutput; + outputJson = ""; + return; + } + const trimmedText = text.trim(); if (!trimmedText) { @@ -112,7 +122,8 @@ } return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } function validateExample(sample: WidgetExample): sample is WidgetExampleTextInput { diff --git a/js/src/lib/components/InferenceWidget/widgets/ImageClassificationWidget/ImageClassificationWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/ImageClassificationWidget/ImageClassificationWidget.svelte index 6f0d56247..720fc0098 100644 --- a/js/src/lib/components/InferenceWidget/widgets/ImageClassificationWidget/ImageClassificationWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/ImageClassificationWidget/ImageClassificationWidget.svelte @@ -36,7 +36,7 @@ async function getOutput( file: File | Blob, - { withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {} + { withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {} ) { if (!file) { return; @@ -108,7 +108,8 @@ return; } const blob = await getBlobFromUrl(imgSrc); - getOutput(blob, opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput(blob, { ...opts.inferenceOpts, exampleOutput }); } function validateExample(sample: WidgetExample): sample is WidgetExampleAssetInput { diff --git a/js/src/lib/components/InferenceWidget/widgets/ImageSegmentationWidget/ImageSegmentationWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/ImageSegmentationWidget/ImageSegmentationWidget.svelte index 91e447ba1..b757567b1 100644 --- a/js/src/lib/components/InferenceWidget/widgets/ImageSegmentationWidget/ImageSegmentationWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/ImageSegmentationWidget/ImageSegmentationWidget.svelte @@ -52,7 +52,7 @@ async function getOutput( file: File | Blob, - { withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {} + { withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {} ) { if (!file) { return; @@ -217,7 +217,8 @@ return; } const blob = await getBlobFromUrl(imgSrc); - getOutput(blob, opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput(blob, { ...opts.inferenceOpts, exampleOutput }); } onMount(() => { diff --git a/js/src/lib/components/InferenceWidget/widgets/ImageToImageWidget/ImageToImageWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/ImageToImageWidget/ImageToImageWidget.svelte index ba84ed9c3..d6930d1c0 100644 --- a/js/src/lib/components/InferenceWidget/widgets/ImageToImageWidget/ImageToImageWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/ImageToImageWidget/ImageToImageWidget.svelte @@ -70,10 +70,15 @@ const res = await fetch(imgSrc); const blob = await res.blob(); await updateImageBase64(blob); - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedPrompt = prompt.trim(); if (!imageBase64) { diff --git a/js/src/lib/components/InferenceWidget/widgets/ImageToTextWidget/ImageToTextWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/ImageToTextWidget/ImageToTextWidget.svelte index 03bdb352f..25671e855 100644 --- a/js/src/lib/components/InferenceWidget/widgets/ImageToTextWidget/ImageToTextWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/ImageToTextWidget/ImageToTextWidget.svelte @@ -35,7 +35,7 @@ async function getOutput( file: File | Blob, - { withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {} + { withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {} ) { if (!file) { return; diff --git a/js/src/lib/components/InferenceWidget/widgets/ObjectDetectionWidget/ObjectDetectionWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/ObjectDetectionWidget/ObjectDetectionWidget.svelte index ff818719b..c4afc41e7 100644 --- a/js/src/lib/components/InferenceWidget/widgets/ObjectDetectionWidget/ObjectDetectionWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/ObjectDetectionWidget/ObjectDetectionWidget.svelte @@ -40,7 +40,7 @@ async function getOutput( file: File | Blob, - { withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {} + { withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {} ) { if (!file) { return; @@ -136,7 +136,8 @@ return; } const blob = await getBlobFromUrl(imgSrc); - getOutput(blob, opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput(blob, { ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/QuestionAnsweringWidget/QuestionAnsweringWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/QuestionAnsweringWidget/QuestionAnsweringWidget.svelte index 6de4507f2..64a1c0055 100644 --- a/js/src/lib/components/InferenceWidget/widgets/QuestionAnsweringWidget/QuestionAnsweringWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/QuestionAnsweringWidget/QuestionAnsweringWidget.svelte @@ -34,7 +34,11 @@ let question = ""; let setTextAreaValue: (text: string) => void; - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedQuestion = question.trim(); const trimmedContext = context.trim(); @@ -113,7 +117,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } function validateExample( diff --git a/js/src/lib/components/InferenceWidget/widgets/SentenceSimilarityWidget/SentenceSimilarityWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/SentenceSimilarityWidget/SentenceSimilarityWidget.svelte index 6df0cd43f..8f177c8d4 100644 --- a/js/src/lib/components/InferenceWidget/widgets/SentenceSimilarityWidget/SentenceSimilarityWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/SentenceSimilarityWidget/SentenceSimilarityWidget.svelte @@ -31,7 +31,11 @@ let output: Array<{ label: string; score: number }> = []; let outputJson: string; - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedSourceSentence = sourceSentence.trim(); if (!trimmedSourceSentence) { error = "You need to input some text"; @@ -124,7 +128,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/SummarizationWidget/SummarizationWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/SummarizationWidget/SummarizationWidget.svelte index 10cd81892..a755f66c2 100644 --- a/js/src/lib/components/InferenceWidget/widgets/SummarizationWidget/SummarizationWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/SummarizationWidget/SummarizationWidget.svelte @@ -29,7 +29,11 @@ let text = ""; let setTextAreaValue: (text: string) => void; - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedValue = text.trim(); if (!trimmedValue) { @@ -94,7 +98,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/TableQuestionAnsweringWidget/TableQuestionAnsweringWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/TableQuestionAnsweringWidget/TableQuestionAnsweringWidget.svelte index 670f71acb..aaee40402 100644 --- a/js/src/lib/components/InferenceWidget/widgets/TableQuestionAnsweringWidget/TableQuestionAnsweringWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/TableQuestionAnsweringWidget/TableQuestionAnsweringWidget.svelte @@ -54,7 +54,11 @@ table = updatedTable; } - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedQuery = query.trim(); if (!trimmedQuery) { @@ -144,7 +148,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/TabularDataWidget/TabularDataWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/TabularDataWidget/TabularDataWidget.svelte index 53e0aa5c3..40095059c 100644 --- a/js/src/lib/components/InferenceWidget/widgets/TabularDataWidget/TabularDataWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/TabularDataWidget/TabularDataWidget.svelte @@ -66,7 +66,11 @@ output = []; } - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { for (let [i, row] of table.entries()) { for (const [j, cell] of row.entries()) { if (!String(cell)) { @@ -168,7 +172,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/TextGenerationWidget/TextGenerationWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/TextGenerationWidget/TextGenerationWidget.svelte index f5e0accbf..26e775f41 100644 --- a/js/src/lib/components/InferenceWidget/widgets/TextGenerationWidget/TextGenerationWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/TextGenerationWidget/TextGenerationWidget.svelte @@ -47,11 +47,23 @@ model.pipeline_tag as PipelineType ); - async function getOutput({ withModelLoading = false, isOnLoadCall = false, useCache = true }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + useCache = true, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { if (isBloomLoginRequired) { return; } + if (exampleOutput) { + output = exampleOutput.text; + outputJson = ""; + renderTypingEffect(output); + return; + } + const trimmedValue = text.trim(); if (!trimmedValue) { @@ -162,7 +174,8 @@ } return; } - getOutput({ useCache, ...opts.inferenceOpts }); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, useCache, exampleOutput }); } function validateExample(sample: WidgetExample): sample is WidgetExampleTextInput { diff --git a/js/src/lib/components/InferenceWidget/widgets/TextToImageWidget/TextToImageWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/TextToImageWidget/TextToImageWidget.svelte index 9ca830adb..fb6b05918 100644 --- a/js/src/lib/components/InferenceWidget/widgets/TextToImageWidget/TextToImageWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/TextToImageWidget/TextToImageWidget.svelte @@ -1,5 +1,5 @@ diff --git a/js/src/lib/components/InferenceWidget/widgets/TokenClassificationWidget/TokenClassificationWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/TokenClassificationWidget/TokenClassificationWidget.svelte index 90987bf8e..5d81d08af 100644 --- a/js/src/lib/components/InferenceWidget/widgets/TokenClassificationWidget/TokenClassificationWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/TokenClassificationWidget/TokenClassificationWidget.svelte @@ -46,7 +46,11 @@ let warning: string = ""; let setTextAreaValue: (text: string) => void; - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedText = text.trim(); if (!trimmedText) { @@ -213,7 +217,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } diff --git a/js/src/lib/components/InferenceWidget/widgets/VisualQuestionAnsweringWidget/VisualQuestionAnsweringWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/VisualQuestionAnsweringWidget/VisualQuestionAnsweringWidget.svelte index ca263d93d..f3de07740 100644 --- a/js/src/lib/components/InferenceWidget/widgets/VisualQuestionAnsweringWidget/VisualQuestionAnsweringWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/VisualQuestionAnsweringWidget/VisualQuestionAnsweringWidget.svelte @@ -74,10 +74,15 @@ const res = await fetch(imgSrc); const blob = await res.blob(); await updateImageBase64(blob); - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedQuestion = question.trim(); if (!trimmedQuestion) { diff --git a/js/src/lib/components/InferenceWidget/widgets/ZeroShotImageClassificationWidget/ZeroShotImageClassificationWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/ZeroShotImageClassificationWidget/ZeroShotImageClassificationWidget.svelte index ed2f1f07d..81d9e5fc0 100644 --- a/js/src/lib/components/InferenceWidget/widgets/ZeroShotImageClassificationWidget/ZeroShotImageClassificationWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/ZeroShotImageClassificationWidget/ZeroShotImageClassificationWidget.svelte @@ -77,10 +77,15 @@ const res = await fetch(imgSrc); const blob = await res.blob(); await updateImageBase64(blob); - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); } - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedCandidateLabels = candidateLabels.trim().split(",").join(","); if (!trimmedCandidateLabels) { diff --git a/js/src/lib/components/InferenceWidget/widgets/ZeroShowClassificationWidget/ZeroShotClassificationWidget.svelte b/js/src/lib/components/InferenceWidget/widgets/ZeroShowClassificationWidget/ZeroShotClassificationWidget.svelte index 4e2ea5f9d..f6438eeef 100644 --- a/js/src/lib/components/InferenceWidget/widgets/ZeroShowClassificationWidget/ZeroShotClassificationWidget.svelte +++ b/js/src/lib/components/InferenceWidget/widgets/ZeroShowClassificationWidget/ZeroShotClassificationWidget.svelte @@ -34,7 +34,11 @@ let warning: string = ""; let setTextAreaValue: (text: string) => void; - async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { + async function getOutput({ + withModelLoading = false, + isOnLoadCall = false, + exampleOutput = undefined, + }: InferenceRunOpts = {}) { const trimmedText = text.trim(); const trimmedCandidateLabels = candidateLabels.trim().split(",").join(","); @@ -128,7 +132,8 @@ if (opts.isPreview) { return; } - getOutput(opts.inferenceOpts); + const exampleOutput = sample.output; + getOutput({ ...opts.inferenceOpts, exampleOutput }); }