From 335b817e9c75782c794c8e72dbe4d7f763794baf Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Wed, 6 Mar 2024 08:44:04 -0800 Subject: [PATCH] feat(client): add streaming support (#56) * feat(client): add streaming support * fix: variable type definition * chore: add docs * chore: bump to client 0.9.0 --- .../app/streaming/page.tsx | 87 ++++++ libs/client/package.json | 3 +- libs/client/src/auth.ts | 26 ++ libs/client/src/index.ts | 1 + libs/client/src/realtime.ts | 27 +- libs/client/src/streaming.ts | 251 ++++++++++++++++++ package-lock.json | 9 + package.json | 1 + 8 files changed, 379 insertions(+), 26 deletions(-) create mode 100644 apps/demo-nextjs-app-router/app/streaming/page.tsx create mode 100644 libs/client/src/auth.ts create mode 100644 libs/client/src/streaming.ts diff --git a/apps/demo-nextjs-app-router/app/streaming/page.tsx b/apps/demo-nextjs-app-router/app/streaming/page.tsx new file mode 100644 index 0000000..d0de70b --- /dev/null +++ b/apps/demo-nextjs-app-router/app/streaming/page.tsx @@ -0,0 +1,87 @@ +'use client'; + +import * as fal from '@fal-ai/serverless-client'; +import { useState } from 'react'; + +fal.config({ + proxyUrl: '/api/fal/proxy', +}); + +type LlavaInput = { + prompt: string; + image_url: string; + max_new_tokens?: number; + temperature?: number; + top_p?: number; +}; + +type LlavaOutput = { + output: string; + partial: boolean; + stats: { + num_input_tokens: number; + num_output_tokens: number; + }; +}; + +export default function StreamingDemo() { + const [answer, setAnswer] = useState(''); + const [streamStatus, setStreamStatus] = useState('idle'); + + const runInference = async () => { + const stream = await fal.stream( + 'fal-ai/llavav15-13b', + { + input: { + prompt: + 'Do you know who drew this picture and what is the name of it?', + image_url: 'https://llava-vl.github.io/static/images/monalisa.jpg', + max_new_tokens: 100, + temperature: 0.2, + top_p: 1, + }, + } + ); + setStreamStatus('running'); + + for await (const partial of stream) { + setAnswer(partial.output); + } + + const result = await stream.done(); + setStreamStatus('done'); + setAnswer(result.output); + }; + + return ( +
+
+

+ Hello fal +{' '} + streaming +

+ +
+ +
+ +
+
+

Answer

+ + streaming: {streamStatus} + +
+

+ {answer} +

+
+
+
+ ); +} diff --git a/libs/client/package.json b/libs/client/package.json index 175d82c..364dfeb 100644 --- a/libs/client/package.json +++ b/libs/client/package.json @@ -1,7 +1,7 @@ { "name": "@fal-ai/serverless-client", "description": "The fal serverless JS/TS client", - "version": "0.8.6", + "version": "0.9.0", "license": "MIT", "repository": { "type": "git", @@ -17,6 +17,7 @@ ], "dependencies": { "@msgpack/msgpack": "^3.0.0-beta2", + "eventsource-parser": "^1.1.2", "robot3": "^0.4.1", "uuid-random": "^1.3.2" }, diff --git a/libs/client/src/auth.ts b/libs/client/src/auth.ts new file mode 100644 index 0000000..6c748b3 --- /dev/null +++ b/libs/client/src/auth.ts @@ -0,0 +1,26 @@ +import { getRestApiUrl } from './config'; +import { dispatchRequest } from './request'; +import { ensureAppIdFormat } from './utils'; + +export const TOKEN_EXPIRATION_SECONDS = 120; + +/** + * Get a token to connect to the realtime endpoint. + */ +export async function getTemporaryAuthToken(app: string): Promise { + const [, appAlias] = ensureAppIdFormat(app).split('/'); + const token: string | object = await dispatchRequest( + 'POST', + `${getRestApiUrl()}/tokens/`, + { + allowed_apps: [appAlias], + token_expiration: TOKEN_EXPIRATION_SECONDS, + } + ); + // keep this in case the response was wrapped (old versions of the proxy do that) + // should be safe to remove in the future + if (typeof token !== 'string' && token['detail']) { + return token['detail']; + } + return token; +} diff --git a/libs/client/src/index.ts b/libs/client/src/index.ts index e84e921..daed06e 100644 --- a/libs/client/src/index.ts +++ b/libs/client/src/index.ts @@ -6,6 +6,7 @@ export { realtimeImpl as realtime } from './realtime'; export { ApiError, ValidationError } from './response'; export type { ResponseHandler } from './response'; export { storageImpl as storage } from './storage'; +export { stream } from './streaming'; export type { QueueStatus, ValidationErrorInfo, diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index e6351fb..9b66690 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -13,8 +13,7 @@ import { transition, } from 'robot3'; import uuid from 'uuid-random'; -import { getRestApiUrl } from './config'; -import { dispatchRequest } from './request'; +import { TOKEN_EXPIRATION_SECONDS, getTemporaryAuthToken } from './auth'; import { ApiError } from './response'; import { isBrowser } from './runtime'; import { ensureAppIdFormat, isReact, throttle } from './utils'; @@ -280,7 +279,6 @@ function buildRealtimeUrl( return `wss://fal.run/${appId}/${suffix}?${queryParams.toString()}`; } -const TOKEN_EXPIRATION_SECONDS = 120; const DEFAULT_THROTTLE_INTERVAL = 128; function shouldSendBinary(message: any): boolean { @@ -292,27 +290,6 @@ function shouldSendBinary(message: any): boolean { ); } -/** - * Get a token to connect to the realtime endpoint. - */ -async function getToken(app: string): Promise { - const [, appAlias] = ensureAppIdFormat(app).split('/'); - const token: string | object = await dispatchRequest( - 'POST', - `${getRestApiUrl()}/tokens/`, - { - allowed_apps: [appAlias], - token_expiration: TOKEN_EXPIRATION_SECONDS, - } - ); - // keep this in case the response was wrapped (old versions of the proxy do that) - // should be safe to remove in the future - if (typeof token !== 'string' && token['detail']) { - return token['detail']; - } - return token; -} - function isUnauthorizedError(message: any): boolean { // TODO we need better protocol definition with error codes return message['status'] === 'error' && message['error'] === 'Unauthorized'; @@ -441,7 +418,7 @@ export const realtimeImpl: RealtimeClient = { previousState !== machine.current ) { send({ type: 'initiateAuth' }); - getToken(app) + getTemporaryAuthToken(app) .then((token) => { send({ type: 'authenticated', token }); const tokenExpirationTimeout = Math.round( diff --git a/libs/client/src/streaming.ts b/libs/client/src/streaming.ts new file mode 100644 index 0000000..115bc96 --- /dev/null +++ b/libs/client/src/streaming.ts @@ -0,0 +1,251 @@ +import { createParser } from 'eventsource-parser'; +import { getTemporaryAuthToken } from './auth'; +import { buildUrl } from './function'; +import { ApiError, defaultResponseHandler } from './response'; +import { storageImpl } from './storage'; + +/** + * The stream API options. It requires the API input and also + * offers configuration options. + */ +type StreamOptions = { + /** + * The API input payload. + */ + input: Input; + + /** + * The maximum time interval in milliseconds between stream chunks. Defaults to 15s. + */ + timeout?: number; + + /** + * Whether it should auto-upload File-like types to fal's storage + * or not. + */ + autoUpload?: boolean; +}; + +const EVENT_STREAM_TIMEOUT = 15 * 1000; + +type FalStreamEventType = 'message' | 'error' | 'done'; + +type EventHandler = (event: any) => void; + +/** + * The class representing a streaming response. With t + */ +class FalStream { + // properties + url: string; + options: StreamOptions; + + // support for event listeners + private listeners: Map = new Map(); + private buffer: Output[] = []; + + // local state + private currentData: Output | undefined = undefined; + private lastEventTimestamp = 0; + private streamClosed = false; + private donePromise: Promise; + + constructor(url: string, options: StreamOptions) { + this.url = url; + this.options = options; + this.donePromise = new Promise((resolve, reject) => { + if (this.streamClosed) { + reject( + new ApiError({ + message: 'Streaming connection is already closed.', + status: 400, + body: undefined, + }) + ); + } + this.on('done', (data) => { + this.streamClosed = true; + resolve(data); + }); + this.on('error', (error) => { + this.streamClosed = true; + reject(error); + }); + }); + this.start(); + } + + private start = async () => { + try { + const response = await fetch(this.url, { + method: 'POST', + headers: { + accept: 'text/event-stream', + 'content-type': 'application/json', + }, + body: JSON.stringify(this.options.input), + }); + this.handleResponse(response); + } catch (error) { + this.handleError(error); + } + }; + + private handleResponse = async (response: Response) => { + if (!response.ok) { + try { + // we know the response failed, call the response handler + // so the exception gets converted to ApiError correctly + await defaultResponseHandler(response); + } catch (error) { + this.emit('error', error); + } + return; + } + + const body = response.body; + if (!body) { + this.emit( + 'error', + new ApiError({ + message: 'Response body is empty.', + status: 400, + body: undefined, + }) + ); + return; + } + const decoder = new TextDecoder('utf-8'); + const reader = response.body.getReader(); + + const parser = createParser((event) => { + if (event.type === 'event') { + const data = event.data; + + try { + const parsedData = JSON.parse(data); + this.buffer.push(parsedData); + this.currentData = parsedData; + this.emit('message', parsedData); + } catch (e) { + this.emit('error', e); + } + } + }); + + const timeout = this.options.timeout ?? EVENT_STREAM_TIMEOUT; + + const readPartialResponse = async () => { + const { value, done } = await reader.read(); + this.lastEventTimestamp = Date.now(); + + parser.feed(decoder.decode(value)); + + if (Date.now() - this.lastEventTimestamp > timeout) { + this.emit( + 'error', + new ApiError({ + message: + 'Event stream timed out after 15 seconds with no messages.', + status: 408, + }) + ); + } + + if (!done) { + readPartialResponse().catch(this.handleError); + } else { + this.emit('done', this.currentData); + } + }; + + readPartialResponse().catch(this.handleError); + return; + }; + + private handleError = (error: any) => { + const apiError = + error instanceof ApiError + ? error + : new ApiError({ + message: error.message ?? 'An unknown error occurred', + status: 500, + }); + this.emit('error', apiError); + return; + }; + + public on = (type: FalStreamEventType, listener: EventHandler) => { + if (!this.listeners.has(type)) { + this.listeners.set(type, []); + } + this.listeners.get(type)?.push(listener); + }; + + private emit = (type: FalStreamEventType, event: any) => { + const listeners = this.listeners.get(type) || []; + for (const listener of listeners) { + listener(event); + } + }; + + async *[Symbol.asyncIterator]() { + let running = true; + const stopAsyncIterator = () => (running = false); + this.on('error', stopAsyncIterator); + this.on('done', stopAsyncIterator); + while (running) { + const data = this.buffer.shift(); + if (data) { + yield data; + } + + // the short timeout ensures the while loop doesn't block other + // frames getting executed concurrently + await new Promise((resolve) => setTimeout(resolve, 16)); + } + } + + /** + * Gets a reference to the `Promise` that indicates whether the streaming + * is done or not. Developers should always call this in their apps to ensure + * the request is over. + * + * An alternative to this, is to use `on('done')` in case your application + * architecture works best with event listeners. + * + * @returns the promise that resolves when the request is done. + */ + public done = async () => this.donePromise; +} + +/** + * Calls a fal app that supports streaming and provides a streaming-capable + * object as a result, that can be used to get partial results through either + * `AsyncIterator` or through an event listener. + * + * @param appId the app id, e.g. `fal-ai/llavav15-13b`. + * @param options the request options, including the input payload. + * @returns the `FalStream` instance. + */ +export async function stream, Output = any>( + appId: string, + options: StreamOptions +): Promise> { + const token = await getTemporaryAuthToken(appId); + const url = buildUrl(appId, { path: '/stream' }); + + const input = + options.input && options.autoUpload !== false + ? await storageImpl.transformInput(options.input) + : options.input; + + const queryParams = new URLSearchParams({ + fal_jwt_token: token, + }); + + return new FalStream(`${url}?${queryParams}`, { + ...options, + input: input as Input, + }); +} diff --git a/package-lock.json b/package-lock.json index d138000..81f01c7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -24,6 +24,7 @@ "cross-fetch": "^3.1.5", "dotenv": "^16.3.1", "encoding": "^0.1.13", + "eventsource-parser": "^1.1.2", "execa": "^8.0.1", "express": "^4.18.2", "fast-glob": "^3.2.12", @@ -14490,6 +14491,14 @@ "node": ">=0.8.x" } }, + "node_modules/eventsource-parser": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-1.1.2.tgz", + "integrity": "sha512-v0eOBUbiaFojBu2s2NPBfYUoRR9GjcDNvCXVaqEf5vVfpIAh9f8RCo4vXTP8c63QRKCFwoLpMpTdPwwhEKVgzA==", + "engines": { + "node": ">=14.18" + } + }, "node_modules/execa": { "version": "8.0.1", "resolved": "https://registry.npmjs.org/execa/-/execa-8.0.1.tgz", diff --git a/package.json b/package.json index 9c9f9bb..bdf6535 100644 --- a/package.json +++ b/package.json @@ -40,6 +40,7 @@ "cross-fetch": "^3.1.5", "dotenv": "^16.3.1", "encoding": "^0.1.13", + "eventsource-parser": "^1.1.2", "execa": "^8.0.1", "express": "^4.18.2", "fast-glob": "^3.2.12",