Update snippets fn typing
beurkinger committed Oct 17, 2023
1 parent 63221a0 commit f5b2f8a
Showing 4 changed files with 31 additions and 23 deletions.
8 changes: 5 additions & 3 deletions js/src/lib/inferenceSnippets/inputs.ts
@@ -1,5 +1,7 @@
import type { PipelineType, ModelData } from "../interfaces/Types";

+type ModelPartial = Pick<ModelData, 'id' | 'pipeline_tag' | 'widgetData'>;
+
const inputsZeroShotClassification = () =>
`"Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"`;

@@ -44,7 +46,7 @@ const inputsTextGeneration = () => `"Can you please let us know more details abo

const inputsText2TextGeneration = () => `"The answer to the universe is"`;

-const inputsFillMask = (model: ModelData) => `"The answer to the universe is ${model.mask_token}."`;
+const inputsFillMask = (model: ModelPartial) => `"The answer to the universe is ${model.mask_token}."`;

const inputsSentenceSimilarity = () =>
`{
@@ -77,7 +79,7 @@ const inputsTextToSpeech = () => `"The answer to the universe is 42"`;
const inputsAutomaticSpeechRecognition = () => `"sample1.flac"`;

const modelInputSnippets: {
-[key in PipelineType]?: (model: ModelData) => string;
+[key in PipelineType]?: (model: ModelPartial) => string;
} = {
"audio-to-audio": inputsAudioToAudio,
"audio-classification": inputsAudioClassification,
@@ -105,7 +107,7 @@ const modelInputSnippets: {

// Use noWrap to put the whole snippet on a single line (removing new lines and tabulations)
// Use noQuotes to strip quotes from start & end (example: "abc" -> abc)
-export function getModelInputSnippet(model: ModelData, noWrap = false, noQuotes = false): string {
+export function getModelInputSnippet(model: ModelPartial, noWrap = false, noQuotes = false): string {
if (model.pipeline_tag) {
const inputs = modelInputSnippets[model.pipeline_tag];
if (inputs) {
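Not part of the commit, but for context: a minimal usage sketch of the retyped helper, assuming it lives alongside inputs.ts; the model id and pipeline tag are made-up example values, not a real lookup.

import type { ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

// Only the three fields picked by ModelPartial are needed now (illustrative values).
const model = {
    id: "distilbert-base-uncased-finetuned-sst-2-english",
    pipeline_tag: "text-classification",
    widgetData: undefined,
} as Pick<ModelData, "id" | "pipeline_tag" | "widgetData">;

console.log(getModelInputSnippet(model));             // quoted example input for the pipeline
console.log(getModelInputSnippet(model, true, true)); // noWrap + noQuotes: single line, quotes stripped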
14 changes: 8 additions & 6 deletions js/src/lib/inferenceSnippets/serveCurl.ts
@@ -1,30 +1,32 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

-export const snippetBasic = (model: ModelData, accessToken: string): string =>
+type ModelPartial = Pick<ModelData, 'id' | 'pipeline_tag' | 'widgetData'>;
+
+export const snippetBasic = (model: ModelPartial, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
-H 'Content-Type: application/json' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

-export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string =>
+export const snippetZeroShotClassification = (model: ModelPartial, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
-H 'Content-Type: application/json' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

-export const snippetFile = (model: ModelData, accessToken: string): string =>
+export const snippetFile = (model: ModelPartial, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
--data-binary '@${getModelInputSnippet(model, true, true)}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

-export const curlSnippets: Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> = {
+export const curlSnippets: Partial<Record<PipelineType, (model: ModelPartial, accessToken: string) => string>> = {
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
@@ -50,12 +52,12 @@ export const curlSnippets: Partial<Record<PipelineType, (model: ModelData, acces
"image-segmentation": snippetFile,
};

-export function getCurlInferenceSnippet(model: ModelData, accessToken: string): string {
+export function getCurlInferenceSnippet(model: ModelPartial, accessToken: string): string {
return model.pipeline_tag && model.pipeline_tag in curlSnippets
? curlSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";
}

-export function hasCurlInferenceSnippet(model: ModelData): boolean {
+export function hasCurlInferenceSnippet(model: ModelPartial): boolean {
return !!model.pipeline_tag && model.pipeline_tag in curlSnippets;
}
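Again not part of the diff: a sketch of how the curl helpers could be called with the narrower type (the relative import path, model values, and the "hf_xxx" token are assumptions for illustration).

import type { ModelData } from "../interfaces/Types";
import { getCurlInferenceSnippet, hasCurlInferenceSnippet } from "./serveCurl";

const model = {
    id: "distilbert-base-uncased-finetuned-sst-2-english",
    pipeline_tag: "text-classification",
    widgetData: undefined,
} as Pick<ModelData, "id" | "pipeline_tag" | "widgetData">;

// getCurlInferenceSnippet returns "" for pipelines without a registered curl snippet,
// so checking hasCurlInferenceSnippet first keeps that case explicit.
if (hasCurlInferenceSnippet(model)) {
    console.log(getCurlInferenceSnippet(model, "hf_xxx")); // placeholder token, not a real credential
}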
16 changes: 9 additions & 7 deletions js/src/lib/inferenceSnippets/serveJs.ts
@@ -1,7 +1,9 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

-export const snippetBasic = (model: ModelData, accessToken: string): string =>
+type ModelPartial = Pick<ModelData, 'id' | 'pipeline_tag' | 'widgetData'>;
+
+export const snippetBasic = (model: ModelPartial, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
@@ -19,7 +21,7 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
console.log(JSON.stringify(response));
});`;

-export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string =>
+export const snippetZeroShotClassification = (model: ModelPartial, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
@@ -39,7 +41,7 @@ query({"inputs": ${getModelInputSnippet(
console.log(JSON.stringify(response));
});`;

-export const snippetTextToImage = (model: ModelData, accessToken: string): string =>
+export const snippetTextToImage = (model: ModelPartial, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
@@ -56,7 +58,7 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
// Use image
});`;

-export const snippetFile = (model: ModelData, accessToken: string): string =>
+export const snippetFile = (model: ModelPartial, accessToken: string): string =>
`async function query(filename) {
const data = fs.readFileSync(filename);
const response = await fetch(
@@ -75,7 +77,7 @@ query(${getModelInputSnippet(model)}).then((response) => {
console.log(JSON.stringify(response));
});`;

-export const jsSnippets: Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> = {
+export const jsSnippets: Partial<Record<PipelineType, (model: ModelPartial, accessToken: string) => string>> = {
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
@@ -101,12 +103,12 @@ export const jsSnippets: Partial<Record<PipelineType, (model: ModelData, accessT
"image-segmentation": snippetFile,
};

-export function getJsInferenceSnippet(model: ModelData, accessToken: string): string {
+export function getJsInferenceSnippet(model: ModelPartial, accessToken: string): string {
return model.pipeline_tag && model.pipeline_tag in jsSnippets
? jsSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";
}

-export function hasJsInferenceSnippet(model: ModelData): boolean {
+export function hasJsInferenceSnippet(model: ModelPartial): boolean {
return !!model.pipeline_tag && model.pipeline_tag in jsSnippets;
}
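A matching sketch for the JS snippet builders (same caveats: the import path, model, and token are illustrative, not from the commit).

import type { ModelData } from "../interfaces/Types";
import { getJsInferenceSnippet, hasJsInferenceSnippet } from "./serveJs";

const model = {
    id: "dslim/bert-base-NER",
    pipeline_tag: "token-classification",
    widgetData: undefined,
} as Pick<ModelData, "id" | "pipeline_tag" | "widgetData">;

if (hasJsInferenceSnippet(model)) {
    // Returns the fetch-based JavaScript snippet as a plain string, ready to render in a code box.
    console.log(getJsInferenceSnippet(model, "hf_xxx"));
}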
16 changes: 9 additions & 7 deletions js/src/lib/inferenceSnippets/servePython.ts
@@ -1,7 +1,9 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

-export const snippetZeroShotClassification = (model: ModelData): string =>
+type ModelPartial = Pick<ModelData, 'id' | 'pipeline_tag' | 'widgetData'>;
+
+export const snippetZeroShotClassification = (model: ModelPartial): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
@@ -11,7 +13,7 @@ output = query({
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;

-export const snippetBasic = (model: ModelData): string =>
+export const snippetBasic = (model: ModelPartial): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
@@ -20,7 +22,7 @@ output = query({
"inputs": ${getModelInputSnippet(model)},
})`;

-export const snippetFile = (model: ModelData): string =>
+export const snippetFile = (model: ModelPartial): string =>
`def query(filename):
with open(filename, "rb") as f:
data = f.read()
@@ -29,7 +31,7 @@ export const snippetFile = (model: ModelData): string =>
output = query(${getModelInputSnippet(model)})`;

-export const snippetTextToImage = (model: ModelData): string =>
+export const snippetTextToImage = (model: ModelPartial): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
@@ -41,7 +43,7 @@ import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`;

-export const pythonSnippets: Partial<Record<PipelineType, (model: ModelData) => string>> = {
+export const pythonSnippets: Partial<Record<PipelineType, (model: ModelPartial) => string>> = {
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
@@ -67,7 +69,7 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelData) =>
"image-segmentation": snippetFile,
};

-export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string {
+export function getPythonInferenceSnippet(model: ModelPartial, accessToken: string): string {
const body =
model.pipeline_tag && model.pipeline_tag in pythonSnippets ? pythonSnippets[model.pipeline_tag]?.(model) ?? "" : "";

@@ -79,6 +81,6 @@ headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Beare
${body}`;
}

-export function hasPythonInferenceSnippet(model: ModelData): boolean {
+export function hasPythonInferenceSnippet(model: ModelPartial): boolean {
return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets;
}
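And one for servePython.ts, whose builder also takes the access token to fill in the generated headers line (again a sketch under the same assumptions; nothing here is part of the commit).

import type { ModelData } from "../interfaces/Types";
import { getPythonInferenceSnippet, hasPythonInferenceSnippet } from "./servePython";

const model = {
    id: "distilbert-base-uncased-finetuned-sst-2-english",
    pipeline_tag: "text-classification",
    widgetData: undefined,
} as Pick<ModelData, "id" | "pipeline_tag" | "widgetData">;

if (hasPythonInferenceSnippet(model)) {
    // The returned string is a ready-to-copy Python requests snippet; "hf_xxx" is a placeholder token.
    console.log(getPythonInferenceSnippet(model, "hf_xxx"));
}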
