Skip to content

Commit

Permalink
feat(client): binary chunk streaming support (#77)
Browse files Browse the repository at this point in the history
* feat(client): binary chunk streaming support

* fix: buffer update

* fix: audio playing

* fix: media resource setup

* feat: allow streaming through the proxy

* fix: legacy urls env

* feat: streaming connection mode

* chore: bump client alpha version

* feat(proxy): enable response streaming when supported

* fix: queue streaming

* chore: deprecated endpoint id cleanup

* fix: client tests

* chore: demo page updates

* chore: bump version for release
  • Loading branch information
drochetti authored Aug 6, 2024
1 parent d9ea6c7 commit 6edbf29
Show file tree
Hide file tree
Showing 15 changed files with 281 additions and 185 deletions.
9 changes: 7 additions & 2 deletions apps/demo-nextjs-app-router/app/queue/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ function Error(props: ErrorProps) {
);
}

const DEFAULT_ENDPOINT_ID = 'fal-ai/fast-sdxl';
const DEFAULT_INPUT = `{
"prompt": "A beautiful sunset over the ocean"
}`;

export default function Home() {
// Input state
const [endpointId, setEndpointId] = useState<string>('');
const [input, setInput] = useState<string>('{}');
const [endpointId, setEndpointId] = useState<string>(DEFAULT_ENDPOINT_ID);
const [input, setInput] = useState<string>(DEFAULT_INPUT);
// Result state
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
Expand Down
2 changes: 1 addition & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.14.0-alpha.3",
"version": "0.14.0",
"license": "MIT",
"repository": {
"type": "git",
Expand Down
13 changes: 0 additions & 13 deletions libs/client/src/function.spec.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import uuid from 'uuid-random';
import { buildUrl } from './function';

describe('The function test suite', () => {
it('should build the URL with a function UUIDv4', () => {
const id = uuid();
const url = buildUrl(`12345/${id}`);
expect(url).toMatch(`trigger/12345/${id}`);
});

it('should build the URL with a function user-id-app-alias', () => {
const alias = '12345-some-alias';
const url = buildUrl(alias);
expect(url).toMatch(`fal.run/12345/some-alias`);
});

it('should build the URL with a function username/app-alias', () => {
const alias = 'fal-ai/text-to-image';
const url = buildUrl(alias);
Expand Down
54 changes: 33 additions & 21 deletions libs/client/src/function.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import { getTemporaryAuthToken } from './auth';
import { dispatchRequest } from './request';
import { storageImpl } from './storage';
import { FalStream } from './streaming';
import { FalStream, StreamingConnectionMode } from './streaming';
import {
CompletedQueueStatus,
EnqueueResult,
QueueStatus,
RequestLog,
} from './types';
import { ensureAppIdFormat, isUUIDv4, isValidUrl, parseAppId } from './utils';
import { ensureAppIdFormat, isValidUrl, parseAppId } from './utils';

/**
* The function input and other configuration when running
Expand Down Expand Up @@ -80,20 +79,13 @@ export function buildUrl<Input>(
Object.keys(params).length > 0
? `?${new URLSearchParams(params).toString()}`
: '';
const parts = id.split('/');

// if a fal url is passed, just use it
if (isValidUrl(id)) {
const url = id.endsWith('/') ? id : `${id}/`;
return `${url}${path}${queryParams}`;
}

// TODO remove this after some time, fal.run should be preferred
if (parts.length === 2 && isUUIDv4(parts[1])) {
const host = 'gateway.shark.fal.ai';
return `https://${host}/trigger/${id}/${path}${queryParams}`;
}

const appId = ensureAppIdFormat(id);
const subdomain = options.subdomain ? `${options.subdomain}.` : '';
const url = `https://${subdomain}fal.run/${appId}/${path}`;
Expand Down Expand Up @@ -199,6 +191,12 @@ type QueueSubscribeOptions = {
}
| {
mode: 'streaming';

/**
* The connection mode to use for streaming updates. It defaults to `server`.
* Set to `client` if your server proxy doesn't support streaming.
*/
connectionMode?: StreamingConnectionMode;
}
);

Expand Down Expand Up @@ -228,6 +226,14 @@ type QueueStatusOptions = BaseQueueOptions & {
logs?: boolean;
};

type QueueStatusStreamOptions = QueueStatusOptions & {
/**
* The connection mode to use for streaming updates. It defaults to `server`.
* Set to `client` if your server proxy doesn't support streaming.
*/
connectionMode?: StreamingConnectionMode;
};

/**
* Represents a request queue with methods for submitting requests,
* checking their status, retrieving results, and subscribing to updates.
Expand Down Expand Up @@ -263,7 +269,7 @@ interface Queue {
*/
streamStatus(
endpointId: string,
options: QueueStatusOptions
options: QueueStatusStreamOptions
): Promise<FalStream<unknown, QueueStatus>>;

/**
Expand Down Expand Up @@ -340,24 +346,26 @@ export const queue: Queue = {

async streamStatus(
endpointId: string,
{ requestId, logs = false }: QueueStatusOptions
{ requestId, logs = false, connectionMode }: QueueStatusStreamOptions
): Promise<FalStream<unknown, QueueStatus>> {
const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : '';
const token = await getTemporaryAuthToken(endpointId);

const queryParams = {
logs: logs ? '1' : '0',
};

const url = buildUrl(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue',
path: `/requests/${requestId}/status/stream`,
query: queryParams,
});

const queryParams = new URLSearchParams({
fal_jwt_token: token,
logs: logs ? '1' : '0',
});

return new FalStream<unknown, QueueStatus>(`${url}?${queryParams}`, {
input: {},
return new FalStream<unknown, QueueStatus>(endpointId, {
url,
method: 'get',
connectionMode,
queryParams,
});
},

Expand All @@ -375,6 +383,10 @@ export const queue: Queue = {
const status = await queue.streamStatus(endpointId, {
requestId,
logs: options.logs,
connectionMode:
'connectionMode' in options
? (options.connectionMode as StreamingConnectionMode)
: undefined,
});
const logs: RequestLog[] = [];
if (timeout) {
Expand All @@ -390,7 +402,7 @@ export const queue: Queue = {
);
}, timeout);
}
status.on('message', (data: QueueStatus) => {
status.on('data', (data: QueueStatus) => {
if (options.onQueueUpdate) {
// accumulate logs to match previous polling behavior
if (
Expand Down
19 changes: 16 additions & 3 deletions libs/client/src/request.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import { getConfig } from './config';
import { ResponseHandler } from './response';
import { getUserAgent, isBrowser } from './runtime';

const isCloudflareWorkers =
typeof navigator !== 'undefined' &&
navigator?.userAgent === 'Cloudflare-Workers';

type RequestOptions = {
responseHandler?: ResponseHandler<any>;
};

export async function dispatchRequest<Input, Output>(
method: string,
targetUrl: string,
input: Input
input: Input,
options: RequestOptions & RequestInit = {}
): Promise<Output> {
const {
credentials: credentialsValue,
Expand Down Expand Up @@ -39,14 +45,21 @@ export async function dispatchRequest<Input, Output>(
...userAgent,
...(headers ?? {}),
} as HeadersInit;

const { responseHandler: customResponseHandler, ...requestInit } = options;
const response = await fetch(url, {
...requestInit,
method,
headers: requestHeaders,
headers: {
...requestHeaders,
...(requestInit.headers ?? {}),
},
...(!isCloudflareWorkers && { mode: 'cors' }),
body:
method.toLowerCase() !== 'get' && input
? JSON.stringify(input)
: undefined,
});
return await responseHandler(response);
const handleResponse = customResponseHandler ?? responseHandler;
return await handleResponse(response);
}
Loading

0 comments on commit 6edbf29

Please sign in to comment.