Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): binary chunk streaming support #77

Merged
merged 14 commits into from
Aug 6, 2024
9 changes: 7 additions & 2 deletions apps/demo-nextjs-app-router/app/queue/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ function Error(props: ErrorProps) {
);
}

const DEFAULT_ENDPOINT_ID = 'fal-ai/fast-sdxl';
const DEFAULT_INPUT = `{
"prompt": "A beautiful sunset over the ocean"
}`;

export default function Home() {
// Input state
const [endpointId, setEndpointId] = useState<string>('');
const [input, setInput] = useState<string>('{}');
const [endpointId, setEndpointId] = useState<string>(DEFAULT_ENDPOINT_ID);
const [input, setInput] = useState<string>(DEFAULT_INPUT);
// Result state
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
Expand Down
131 changes: 131 additions & 0 deletions apps/demo-nextjs-app-router/app/streaming/audio/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
'use client';

import * as fal from '@fal-ai/serverless-client';
import { useRef, useState } from 'react';
import uuid from 'uuid-random';

fal.config({
proxyUrl: '/api/fal/proxy',
});

type PlayHTInput = {
text: string;
request_id: string;
};

const DEFAULT_PROMPT =
"As she sat watching the world go by, something caught her eye. It wasn't so much its color or shape, but the way it was moving. She squinted to see if she could better understand what it was and where it was going, but it didn't help. As she continued to stare into the distance, she didn't understand why this uneasiness was building inside her body.";

export default function AudioStreamingDemo() {
const [prompt, setPrompt] = useState<string>(DEFAULT_PROMPT);
const [streamStatus, setStreamStatus] = useState<string>('idle');
const [timeToFirstChunk, setTimeToFirstChunk] = useState<number | null>(null);

const audioRef = useRef<HTMLAudioElement | null>(null);
const mediaSourceRef = useRef<MediaSource | null>(null);
const sourceBufferRef = useRef<SourceBuffer | null>(null);

const setupMediaSource = () => {
if (!audioRef.current) {
console.warn('Audio element not found or not ready');
return null;
}
const mediaSource = new MediaSource();
mediaSourceRef.current = mediaSource;
const url = URL.createObjectURL(mediaSource);
audioRef.current.src = url;
mediaSource.addEventListener('sourceopen', () => {
const sourceBuffer = mediaSource.addSourceBuffer('audio/mpeg');
sourceBufferRef.current = sourceBuffer;
});
return url;
};

const runInference = async () => {
setupMediaSource();
setTimeToFirstChunk(null);
const startedAt = Date.now();
const stream = await fal.stream<PlayHTInput, Uint8Array>(
'fal-ai/playht-tts',
{
input: {
text: prompt,
request_id: uuid(),
},
accept: 'audio/*',
connectionMode: 'client',
}
);
setStreamStatus('running');
let firstChunk = true;

stream.on('data', (data: Uint8Array) => {
if (audioRef.current?.paused) {
audioRef.current?.play();
}
if (firstChunk) {
setTimeToFirstChunk(Date.now() - startedAt);
firstChunk = false;
}
const sourceBuffer = sourceBufferRef.current;

if (sourceBuffer) {
sourceBuffer.appendBuffer(data);
} else {
console.warn('Source buffer not found or not ready');
}
});

await stream.done();
setStreamStatus('done');
sourceBufferRef.current?.addEventListener('updateend', () => {
mediaSourceRef.current?.endOfStream();
});
};

return (
<div className="min-h-screen dark:bg-gray-900 bg-gray-100">
<main className="container dark:text-gray-50 text-gray-900 flex flex-col items-center justify-center w-full flex-1 py-10 space-y-8">
<h1 className="text-2xl font-bold mb-8">
Hello <code className="text-pink-600">fal</code> +{' '}
<code className="text-indigo-500">streaming</code>
</h1>

<div className="flex flex-col space-y-2 flex-1 w-full">
<textarea
className="flex-1 p-2 rounded bg-black/10 dark:bg-white/5 border border-black/20 dark:border-white/10 text-sm"
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
placeholder="Prompt"
rows={4}
></textarea>
<button
onClick={() => runInference().catch(console.error)}
className="bg-indigo-600 hover:bg-indigo-700 text-white font-bold py-3 px-6 mx-auto rounded focus:outline-none focus:shadow-outline disabled:opacity-70"
disabled={streamStatus === 'running'}
>
Run inference
</button>
</div>

<div className="w-full flex flex-col space-y-4">
<div className="flex flex-row items-center justify-between">
<h2 className="text-2xl font-bold">Result</h2>
<div className="space-x-4">
<span>
time to first chunk:{' '}
<code className="font-semibold">
{timeToFirstChunk ? `${timeToFirstChunk}ms` : 'n/a'}
</code>
</span>
<span>
streaming: <code className="font-semibold">{streamStatus}</code>
</span>
</div>
</div>
<audio controls className="w-full" ref={audioRef} />
</div>
</main>
</div>
);
}
2 changes: 1 addition & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.14.0-alpha.3",
"version": "0.14.0-alpha.8",
"license": "MIT",
"repository": {
"type": "git",
Expand Down
13 changes: 0 additions & 13 deletions libs/client/src/function.spec.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import uuid from 'uuid-random';
import { buildUrl } from './function';

describe('The function test suite', () => {
it('should build the URL with a function UUIDv4', () => {
const id = uuid();
const url = buildUrl(`12345/${id}`);
expect(url).toMatch(`trigger/12345/${id}`);
});

it('should build the URL with a function user-id-app-alias', () => {
const alias = '12345-some-alias';
const url = buildUrl(alias);
expect(url).toMatch(`fal.run/12345/some-alias`);
});

it('should build the URL with a function username/app-alias', () => {
const alias = 'fal-ai/text-to-image';
const url = buildUrl(alias);
Expand Down
54 changes: 33 additions & 21 deletions libs/client/src/function.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import { getTemporaryAuthToken } from './auth';
import { dispatchRequest } from './request';
import { storageImpl } from './storage';
import { FalStream } from './streaming';
import { FalStream, StreamingConnectionMode } from './streaming';
import {
CompletedQueueStatus,
EnqueueResult,
QueueStatus,
RequestLog,
} from './types';
import { ensureAppIdFormat, isUUIDv4, isValidUrl, parseAppId } from './utils';
import { ensureAppIdFormat, isValidUrl, parseAppId } from './utils';

/**
* The function input and other configuration when running
Expand Down Expand Up @@ -80,20 +79,13 @@ export function buildUrl<Input>(
Object.keys(params).length > 0
? `?${new URLSearchParams(params).toString()}`
: '';
const parts = id.split('/');

// if a fal url is passed, just use it
if (isValidUrl(id)) {
const url = id.endsWith('/') ? id : `${id}/`;
return `${url}${path}${queryParams}`;
}

// TODO remove this after some time, fal.run should be preferred
if (parts.length === 2 && isUUIDv4(parts[1])) {
const host = 'gateway.shark.fal.ai';
return `https://${host}/trigger/${id}/${path}${queryParams}`;
}

const appId = ensureAppIdFormat(id);
const subdomain = options.subdomain ? `${options.subdomain}.` : '';
const url = `https://${subdomain}fal.run/${appId}/${path}`;
Expand Down Expand Up @@ -199,6 +191,12 @@ type QueueSubscribeOptions = {
}
| {
mode: 'streaming';

/**
* The connection mode to use for streaming updates. It defaults to `server`.
* Set to `client` if your server proxy doesn't support streaming.
*/
connectionMode?: StreamingConnectionMode;
}
);

Expand Down Expand Up @@ -228,6 +226,14 @@ type QueueStatusOptions = BaseQueueOptions & {
logs?: boolean;
};

type QueueStatusStreamOptions = QueueStatusOptions & {
/**
* The connection mode to use for streaming updates. It defaults to `server`.
* Set to `client` if your server proxy doesn't support streaming.
*/
connectionMode?: StreamingConnectionMode;
};

/**
* Represents a request queue with methods for submitting requests,
* checking their status, retrieving results, and subscribing to updates.
Expand Down Expand Up @@ -263,7 +269,7 @@ interface Queue {
*/
streamStatus(
endpointId: string,
options: QueueStatusOptions
options: QueueStatusStreamOptions
): Promise<FalStream<unknown, QueueStatus>>;

/**
Expand Down Expand Up @@ -340,24 +346,26 @@ export const queue: Queue = {

async streamStatus(
endpointId: string,
{ requestId, logs = false }: QueueStatusOptions
{ requestId, logs = false, connectionMode }: QueueStatusStreamOptions
): Promise<FalStream<unknown, QueueStatus>> {
const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : '';
const token = await getTemporaryAuthToken(endpointId);

const queryParams = {
logs: logs ? '1' : '0',
};

const url = buildUrl(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue',
path: `/requests/${requestId}/status/stream`,
query: queryParams,
});

const queryParams = new URLSearchParams({
fal_jwt_token: token,
logs: logs ? '1' : '0',
});

return new FalStream<unknown, QueueStatus>(`${url}?${queryParams}`, {
input: {},
return new FalStream<unknown, QueueStatus>(endpointId, {
url,
method: 'get',
connectionMode,
queryParams,
});
},

Expand All @@ -375,6 +383,10 @@ export const queue: Queue = {
const status = await queue.streamStatus(endpointId, {
requestId,
logs: options.logs,
connectionMode:
'connectionMode' in options
? (options.connectionMode as StreamingConnectionMode)
: undefined,
});
const logs: RequestLog[] = [];
if (timeout) {
Expand All @@ -390,7 +402,7 @@ export const queue: Queue = {
);
}, timeout);
}
status.on('message', (data: QueueStatus) => {
status.on('data', (data: QueueStatus) => {
if (options.onQueueUpdate) {
// accumulate logs to match previous polling behavior
if (
Expand Down
19 changes: 16 additions & 3 deletions libs/client/src/request.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import { getConfig } from './config';
import { ResponseHandler } from './response';
import { getUserAgent, isBrowser } from './runtime';

const isCloudflareWorkers =
typeof navigator !== 'undefined' &&
navigator?.userAgent === 'Cloudflare-Workers';

type RequestOptions = {
responseHandler?: ResponseHandler<any>;

Check warning on line 10 in libs/client/src/request.ts

View workflow job for this annotation

GitHub Actions / build

Unexpected any. Specify a different type
};

export async function dispatchRequest<Input, Output>(
method: string,
targetUrl: string,
input: Input
input: Input,
options: RequestOptions & RequestInit = {}
): Promise<Output> {
const {
credentials: credentialsValue,
Expand Down Expand Up @@ -39,14 +45,21 @@
...userAgent,
...(headers ?? {}),
} as HeadersInit;

const { responseHandler: customResponseHandler, ...requestInit } = options;
const response = await fetch(url, {
...requestInit,
method,
headers: requestHeaders,
headers: {
...requestHeaders,
...(requestInit.headers ?? {}),
},
...(!isCloudflareWorkers && { mode: 'cors' }),
body:
method.toLowerCase() !== 'get' && input
? JSON.stringify(input)
: undefined,
});
return await responseHandler(response);
const handleResponse = customResponseHandler ?? responseHandler;
return await handleResponse(response);
}
Loading
Loading