Skip to content

Commit

Permalink
Proposal: widget types (#980)
Browse files Browse the repository at this point in the history
* wip

* 🩹 Fix types in widgets

* 🩹 Minimize diff

* 💄 Lint

* Update js/src/lib/components/InferenceWidget/shared/WidgetInputSamples/WidgetInputSamples.svelte

Co-authored-by: Julien Chaumond <[email protected]>

* 🧹 cleanup unwanted change

* 🩹 wip: type safety

* ✨ Typing & validation for ALL widgets

* ♻️ Slightly less verbose version

* 🩹

* 🩹 Import type?

* 🩹 Alternative generic syntax to satisfy svelte-check

---------

Co-authored-by: Julien Chaumond <[email protected]>
  • Loading branch information
SBrandeis and julien-c authored Sep 29, 2023
1 parent b6b21c8 commit 5b91fc5
Show file tree
Hide file tree
Showing 27 changed files with 310 additions and 122 deletions.
89 changes: 42 additions & 47 deletions js/src/lib/components/InferenceWidget/shared/WidgetExample.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,83 +12,78 @@ export interface WidgetExampleOutputText {
export interface WidgetExampleOutputUrl {
url: string;
}

export type WidgetExampleOutput =
| WidgetExampleOutputLabels
| WidgetExampleOutputAnswerScore
| WidgetExampleOutputText
| WidgetExampleOutputUrl;
//#endregion

export interface WidgetExampleBase {
export interface WidgetExampleBase<TOutput> {
example_title?: string;
group?: string;
output?: TOutput;
}

export interface WidgetExampleTextInputLabelsOutput extends WidgetExampleBase {
text: string;
output?: WidgetExampleOutputLabels;
export interface WidgetExampleTextInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
text: string;
}

export interface WidgetExampleTextAndContextInputAnswerScoreOutput extends WidgetExampleBase {
text: string;
export interface WidgetExampleTextAndContextInput<TOutput = WidgetExampleOutput>
extends WidgetExampleTextInput<TOutput> {
context: string;
output?: WidgetExampleOutputAnswerScore;
}

export interface WidgetExampleTextInputTextOutput extends WidgetExampleBase {
text: string;
output?: WidgetExampleOutputText;
export interface WidgetExampleTextAndTableInput<TOutput = WidgetExampleOutput> extends WidgetExampleTextInput<TOutput> {
table: (string | number)[][];
}

export interface WidgetExampleTextInputUrlOutput extends WidgetExampleBase {
text: string;
output?: WidgetExampleOutputUrl;
export interface WidgetExampleAssetInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
src: string;
}

export interface WidgetExampleAssetInputLabelsOutput extends WidgetExampleBase {
src: string;
output?: WidgetExampleOutputLabels;
export interface WidgetExampleAssetAndPromptInput<TOutput = WidgetExampleOutput>
extends WidgetExampleAssetInput<TOutput> {
prompt: string;
}

export interface WidgetExampleAssetInputTextOutput extends WidgetExampleBase {
src: string;
output?: WidgetExampleOutputText;
}
export type WidgetExampleAssetAndTextInput<TOutput = WidgetExampleOutput> = WidgetExampleAssetInput<TOutput> &
WidgetExampleTextInput<TOutput>;

export interface WidgetExampleAssetInputUrlOutput extends WidgetExampleBase {
src: string;
output?: WidgetExampleOutputUrl;
}
export type WidgetExampleAssetAndZeroShotInput<TOutput = WidgetExampleOutput> = WidgetExampleAssetInput<TOutput> &
WidgetExampleZeroShotTextInput<TOutput>;

//#region more exotic stuff
export interface WidgetExampleStructuredDataInputLabelsOutput extends WidgetExampleBase {
export interface WidgetExampleStructuredDataInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
structuredData: TableData;
output?: WidgetExampleOutputLabels;
}

export interface WidgetExampleTableDataInputLabelsOutput extends WidgetExampleBase {
table: TableData;
output?: WidgetExampleOutputLabels;
export interface WidgetExampleTableDataInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
table: TableData;
}

export interface WidgetExampleZeroShotTextInputLabelsOutput extends WidgetExampleBase {
export interface WidgetExampleZeroShotTextInput<TOutput = WidgetExampleOutput> extends WidgetExampleTextInput<TOutput> {
text: string;
candidate_labels: string;
multi_class: boolean;
output?: WidgetExampleOutputLabels;
}

export interface WidgetExampleSentenceSimilarityInputLabelsOutput extends WidgetExampleBase {
export interface WidgetExampleSentenceSimilarityInput<TOutput = WidgetExampleOutput>
extends WidgetExampleBase<TOutput> {
source_sentence: string;
sentences: string[];
output?: WidgetExampleOutputLabels;
}

//#endregion

export type WidgetExample =
| WidgetExampleTextInputLabelsOutput
| WidgetExampleTextAndContextInputAnswerScoreOutput
| WidgetExampleTextInputTextOutput
| WidgetExampleTextInputUrlOutput
| WidgetExampleAssetInputLabelsOutput
| WidgetExampleAssetInputTextOutput
| WidgetExampleAssetInputUrlOutput
| WidgetExampleStructuredDataInputLabelsOutput
| WidgetExampleTableDataInputLabelsOutput
| WidgetExampleZeroShotTextInputLabelsOutput
| WidgetExampleSentenceSimilarityInputLabelsOutput;
export type WidgetExample<TOutput = WidgetExampleOutput> =
| WidgetExampleTextInput<TOutput>
| WidgetExampleTextAndContextInput<TOutput>
| WidgetExampleTextAndTableInput<TOutput>
| WidgetExampleAssetInput<TOutput>
| WidgetExampleAssetAndPromptInput<TOutput>
| WidgetExampleAssetAndTextInput<TOutput>
| WidgetExampleAssetAndZeroShotInput<TOutput>
| WidgetExampleStructuredDataInput<TOutput>
| WidgetExampleTableDataInput<TOutput>
| WidgetExampleZeroShotTextInput<TOutput>
| WidgetExampleSentenceSimilarityInput<TOutput>;
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
<script lang="ts" generics="T extends WidgetExample">
<script lang="ts">
import type { WidgetExample } from "../WidgetExample";
import { slide } from "svelte/transition";
import IconCaretDownV2 from "../../../Icons/IconCaretDownV2.svelte";
type TWidgetExample = $$Generic<WidgetExample>;
export let classNames = "";
export let isLoading = false;
export let inputSamples: WidgetExample[];
export let applyInputSample: (sample: T) => void;
export let previewInputSample: (sample: T) => void;
export let inputSamples: TWidgetExample[];
export let applyInputSample: (sample: TWidgetExample) => void;
export let previewInputSample: (sample: TWidgetExample) => void;
let containerEl: HTMLElement;
let isOptionsVisible = false;
Expand All @@ -25,12 +27,12 @@
hideOptions();
const sample = inputSamples[idx];
title = sample.example_title as string;
applyInputSample(sample as T);
applyInputSample(sample);
}
function _previewInputSample(idx: number) {
const sample = inputSamples[idx];
previewInputSample(sample as T);
previewInputSample(sample);
}
function toggleOptionsVisibility() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<script lang="ts" generics="T extends WidgetExample">
<script lang="ts" generics="TWidgetExample extends WidgetExample">
import type { WidgetProps, ModelLoadInfo } from "../types";
import type { WidgetExample } from "../WidgetExample";
Expand Down Expand Up @@ -26,22 +26,27 @@
};
export let noTitle = false;
export let outputJson: string;
export let applyInputSample: (sample: T) => void = () => {};
export let previewInputSample: (sample: T) => void = () => {};
export let applyInputSample: (sample: TWidgetExample) => void = () => {};
export let previewInputSample: (sample: TWidgetExample) => void = () => {};
export let validateExample: (sample: WidgetExample) => sample is TWidgetExample;
let isMaximized = false;
let modelLoadInfo: ModelLoadInfo | undefined = undefined;
let selectedInputGroup: string;
const inputSamples: WidgetExample[] = (model.widgetData ?? [])
const inputSamples = (model.widgetData ?? [])
.filter(validateExample)
.sort((sample1, sample2) => (sample2.example_title ? 1 : 0) - (sample1.example_title ? 1 : 0))
.map((sample, idx) => ({
example_title: `Example ${++idx}`,
group: "Group 1",
...sample,
}));
const inputGroups: { group: string; inputSamples: WidgetExample[] }[] = [];
const inputGroups: {
group: string;
inputSamples: TWidgetExample[];
}[] = [];
for (const inputSample of inputSamples) {
const isExist = inputGroups.find(({ group }) => group === inputSample.group);
if (!isExist) {
Expand Down Expand Up @@ -80,7 +85,7 @@
</p>
{:else}
{#if isMaximized}
<button class="absolute top-6 right-12" on:click={onClickMaximizeBtn}>
<button class="absolute right-12 top-6" on:click={onClickMaximizeBtn}>
<IconCross classNames="text-xl text-gray-500 hover:text-black" />
</button>
{/if}
Expand Down
87 changes: 87 additions & 0 deletions js/src/lib/components/InferenceWidget/shared/inputValidation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import type {
WidgetExample,
WidgetExampleAssetAndPromptInput,
WidgetExampleAssetAndTextInput,
WidgetExampleAssetAndZeroShotInput,
WidgetExampleAssetInput,
WidgetExampleSentenceSimilarityInput,
WidgetExampleStructuredDataInput,
WidgetExampleTableDataInput,
WidgetExampleTextAndContextInput,
WidgetExampleTextAndTableInput,
WidgetExampleTextInput,
WidgetExampleZeroShotTextInput,
} from "./WidgetExample";

export function isTextInput<TOutput>(sample: WidgetExample<TOutput>): sample is WidgetExampleTextInput<TOutput> {
return "text" in sample;
}

export function isTextAndContextInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleTextAndContextInput<TOutput> {
return isTextInput(sample) && "context" in sample;
}

export function isAssetInput<TOutput>(sample: WidgetExample<TOutput>): sample is WidgetExampleAssetInput<TOutput> {
return "src" in sample;
}

export function isAssetAndPromptInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleAssetAndPromptInput<TOutput> {
return isAssetInput(sample) && "prompt" in sample && typeof sample.prompt === "string";
}

export function isAssetAndTextInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleAssetAndTextInput<TOutput> {
return isAssetInput(sample) && isTextInput(sample);
}

export function isStructuredDataInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleStructuredDataInput<TOutput> {
return "structuredData" in sample;
}

export function isTableDataInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleTableDataInput<TOutput> {
return "table" in sample;
}

function _isZeroShotTextInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is Exclude<WidgetExampleZeroShotTextInput<TOutput>, "text"> {
return "candidate_labels" in sample && "multi_class" in sample;
}

export function isZeroShotTextInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleZeroShotTextInput<TOutput> {
return isTextInput(sample) && _isZeroShotTextInput(sample);
}

export function isSentenceSimilarityInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleSentenceSimilarityInput<TOutput> {
return "source_sentence" in sample && "sentences" in sample;
}

export function isTextAndTableInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleTextAndTableInput<TOutput> {
return (
isTextInput(sample)
&& "table" in sample
&& Array.isArray(sample.table)
&& sample.table.every(r => Array.isArray(r) && r.every(c => typeof c === "string" || typeof c === "number"))
);
}

export function isAssetAndZeroShotInput<TOutput>(
sample: WidgetExample<TOutput>
): sample is WidgetExampleAssetAndZeroShotInput<TOutput> {
return isAssetInput(sample) && _isZeroShotTextInput(sample);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<script lang="ts">
import type { WidgetProps } from "../../shared/types";
import type { WidgetExampleAssetInputLabelsOutput } from "../../shared/WidgetExample";
import type { WidgetExample, WidgetExampleAssetInput, WidgetExampleOutputLabels } from "../../shared/WidgetExample";
import { onMount } from "svelte";
Expand All @@ -12,6 +12,7 @@
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
import { getResponse, getBlobFromUrl, getDemoInputs } from "../../shared/helpers";
import { isValidOutputLabels } from "../../shared/outputValidation";
import { isAssetInput } from "../../shared/inputValidation";
export let apiToken: WidgetProps["apiToken"];
export let apiUrl: WidgetProps["apiUrl"];
Expand Down Expand Up @@ -121,15 +122,15 @@
throw new TypeError("Invalid output: output must be of type Array<label: string, score:number>");
}
function applyInputSample(sample: WidgetExampleAssetInputLabelsOutput) {
function applyInputSample(sample: WidgetExampleAssetInput<WidgetExampleOutputLabels>) {
file = null;
filename = sample.example_title!;
fileUrl = sample.src;
selectedSampleUrl = sample.src;
getOutput();
}
function previewInputSample(sample: WidgetExampleAssetInputLabelsOutput) {
function previewInputSample(sample: WidgetExampleAssetInput<WidgetExampleOutputLabels>) {
filename = sample.example_title!;
fileUrl = sample.src;
if (isValidOutputLabels(sample.output)) {
Expand All @@ -141,6 +142,10 @@
}
}
function validateExample(sample: WidgetExample): sample is WidgetExampleAssetInput<WidgetExampleOutputLabels> {
return isAssetInput(sample) && (!sample.output || isValidOutputLabels(sample.output));
}
onMount(() => {
const [exampleTitle, src] = getDemoInputs(model, ["example_title", "src"]);
if (callApiOnMount && src) {
Expand All @@ -164,12 +169,13 @@
{noTitle}
{outputJson}
{previewInputSample}
{validateExample}
>
<svelte:fragment slot="top">
<form>
<div class="flex flex-wrap items-center">
<WidgetFileInput accept="audio/*" classNames="mt-1.5 mr-2" {onSelectFile} />
<span class="mt-1.5 mr-2">or</span>
<span class="mr-2 mt-1.5">or</span>
<WidgetRecorder classNames="mt-1.5" {onRecordStart} onRecordStop={onSelectFile} onError={onRecordError} />
</div>
{#if fileUrl}
Expand Down
Loading

0 comments on commit 5b91fc5

Please sign in to comment.