Skip to content

Commit

Permalink
feat: client file upload (#18)
Browse files Browse the repository at this point in the history
* feat: client file upload

* feat: signed upload

* fix: add multipart header

* fix: remove unused file

* chore(wip): own signature impl

* feat: use gcs presigned upload

* fix: invalid export

* fix: rest api host url

* feat: final upload logic and sample
  • Loading branch information
drochetti authored Nov 3, 2023
1 parent 3a98fd6 commit 78ffa58
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 63 deletions.
63 changes: 42 additions & 21 deletions apps/demo-nextjs-app-router/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Image = {
file_size: number;
};
type Result = {
images: Image[];
image: Image;
};
// @snippet:end

Expand All @@ -42,12 +42,13 @@ function Error(props: ErrorProps) {
}

const DEFAULT_PROMPT =
'a city landscape of a cyberpunk metropolis, raining, purple, pink and teal neon lights, highly detailed, uhd';
'(masterpiece:1.4), (best quality), (detailed), Medieval village scene with busy streets and castle in the distance';

export default function Home() {
// @snippet:start("client.ui.state")
// Input state
const [prompt, setPrompt] = useState<string>(DEFAULT_PROMPT);
const [imageFile, setImageFile] = useState<File | null>(null);
// Result state
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
Expand All @@ -59,7 +60,10 @@ export default function Home() {
if (!result) {
return null;
}
return result.images[0];
if (result.image) {
return result.image;
}
return null;
}, [result]);

const reset = () => {
Expand All @@ -76,24 +80,27 @@ export default function Home() {
setLoading(true);
const start = Date.now();
try {
const result: Result = await fal.subscribe('110602490-lora', {
input: {
prompt,
model_name: 'stabilityai/stable-diffusion-xl-base-1.0',
image_size: 'square_hd',
},
pollInterval: 5000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
if (
update.status === 'IN_PROGRESS' ||
update.status === 'COMPLETED'
) {
setLogs((update.logs || []).map((log) => log.message));
}
},
});
const result: Result = await fal.subscribe(
'54285744-illusion-diffusion',
{
input: {
prompt,
image_url: imageFile,
image_size: 'square_hd',
},
pollInterval: 5000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
if (
update.status === 'IN_PROGRESS' ||
update.status === 'COMPLETED'
) {
setLogs((update.logs || []).map((log) => log.message));
}
},
}
);
setResult(result);
} catch (error: any) {
setError(error);
Expand All @@ -109,6 +116,20 @@ export default function Home() {
<h1 className="text-4xl font-bold mb-8">
Hello <code className="font-light text-pink-600">fal</code>
</h1>
<div className="text-lg w-full">
<label htmlFor="prompt" className="block mb-2 text-current">
Image
</label>
<input
className="w-full text-lg p-2 rounded bg-black/10 dark:bg-white/5 border border-black/20 dark:border-white/10"
id="image_url"
name="image_url"
type="file"
placeholder="Choose a file"
accept="image/*"
onChange={(e) => setImageFile(e.target.files?.[0] ?? null)}
/>
</div>
<div className="text-lg w-full">
<label htmlFor="prompt" className="block mb-2 text-current">
Prompt
Expand Down
2 changes: 1 addition & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.4.2",
"version": "0.5.0",
"license": "MIT",
"repository": {
"type": "git",
Expand Down
51 changes: 10 additions & 41 deletions libs/client/src/function.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { getConfig } from './config';
import { getUserAgent, isBrowser } from './runtime';
import { storageImpl } from './storage';
import { dispatchRequest } from './request';
import { EnqueueResult, QueueStatus } from './types';
import { isUUIDv4, isValidUrl } from './utils';

Expand Down Expand Up @@ -62,7 +63,6 @@ export function buildUrl<Input>(

/**
* Runs a fal serverless function identified by its `id`.
* TODO: expand documentation and provide examples
*
* @param id the registered function revision id or alias.
* @returns the remote function output
Expand All @@ -71,45 +71,14 @@ export async function run<Input, Output>(
id: string,
options: RunOptions<Input> = {}
): Promise<Output> {
const {
credentials: credentialsValue,
requestMiddleware,
responseHandler,
} = getConfig();
const method = (options.method ?? 'post').toLowerCase();
const userAgent = isBrowser() ? {} : { 'User-Agent': getUserAgent() };
const credentials =
typeof credentialsValue === 'function'
? credentialsValue()
: credentialsValue;

const { url, headers } = await requestMiddleware({
url: buildUrl(id, options),
});
const authHeader = credentials ? { Authorization: `Key ${credentials}` } : {};
if (typeof window !== 'undefined' && credentials) {
console.warn(
"The fal credentials are exposed in the browser's environment. " +
"That's not recommended for production use cases."
);
}
const requestHeaders = {
...authHeader,
Accept: 'application/json',
'Content-Type': 'application/json',
...userAgent,
...(headers ?? {}),
} as HeadersInit;
const response = await fetch(url, {
method,
headers: requestHeaders,
mode: 'cors',
body:
method !== 'get' && options.input
? JSON.stringify(options.input)
: undefined,
});
return await responseHandler(response);
const input = options.input
? await storageImpl.transformInput(options.input)
: options.input;
return dispatchRequest<Input, Output>(
options.method ?? 'post',
buildUrl(id, options),
input as Input
);
}

/**
Expand Down
1 change: 1 addition & 0 deletions libs/client/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export { config, getConfig } from './config';
export { storageImpl as storage } from './storage';
export { queue, run, subscribe } from './function';
export { withMiddleware, withProxy } from './middleware';
export type { RequestMiddleware } from './middleware';
Expand Down
47 changes: 47 additions & 0 deletions libs/client/src/request.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { getConfig } from './config';
import { getUserAgent, isBrowser } from './runtime';

export async function dispatchRequest<Input, Output>(
method: string,
targetUrl: string,
input: Input
): Promise<Output> {
const {
credentials: credentialsValue,
requestMiddleware,
responseHandler,
} = getConfig();
const userAgent = isBrowser() ? {} : { 'User-Agent': getUserAgent() };
const credentials =
typeof credentialsValue === 'function'
? credentialsValue()
: credentialsValue;

const { url, headers } = await requestMiddleware({
url: targetUrl,
});
const authHeader = credentials ? { Authorization: `Key ${credentials}` } : {};
if (typeof window !== 'undefined' && credentials) {
console.warn(
"The fal credentials are exposed in the browser's environment. " +
"That's not recommended for production use cases."
);
}
const requestHeaders = {
...authHeader,
Accept: 'application/json',
'Content-Type': 'application/json',
...userAgent,
...(headers ?? {}),
} as HeadersInit;
const response = await fetch(url, {
method,
headers: requestHeaders,
mode: 'cors',
body:
method.toLowerCase() !== 'get' && input
? JSON.stringify(input)
: undefined,
});
return await responseHandler(response);
}
108 changes: 108 additions & 0 deletions libs/client/src/storage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import { getConfig } from './config';
import { dispatchRequest } from './request';

/**
* File support for the client. This interface establishes the contract for
* uploading files to the server and transforming the input to replace file
* objects with URLs.
*/
export interface StorageSupport {
/**
* Upload a file to the server. Returns the URL of the uploaded file.
* @param file the file to upload
* @param options optional parameters, such as custom file name
* @returns the URL of the uploaded file
*/
upload: (file: Blob) => Promise<string>;

/**
* Transform the input to replace file objects with URLs. This is used
* to transform the input before sending it to the server and ensures
* that the server receives URLs instead of file objects.
*
* @param input the input to transform.
* @returns the transformed input.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
transformInput: (input: Record<string, any>) => Promise<Record<string, any>>;
}

function isDataUri(uri: string): boolean {
// avoid uri parsing if it doesn't start with data:
if (!uri.startsWith('data:')) {
return false;
}
try {
const url = new URL(uri);
return url.protocol === 'data:';
} catch (_) {
return false;
}
}

type InitiateUploadResult = {
file_url: string;
upload_url: string;
};

type InitiateUploadData = {
file_name: string;
content_type: string | null;
};

function getRestApiUrl(): string {
const { host } = getConfig();
return host.replace('gateway', 'rest');
}

async function initiateUpload(file: Blob): Promise<InitiateUploadResult> {
return await dispatchRequest<InitiateUploadData, InitiateUploadResult>(
'POST',
`https://${getRestApiUrl()}/storage/upload/initiate`,
{
file_name: file.name,
content_type: file.type || 'application/octet-stream',
}
);
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type KeyValuePair = [string, any];

export const storageImpl: StorageSupport = {
upload: async (file: Blob) => {
const { upload_url: uploadUrl, file_url: url } = await initiateUpload(file);
const response = await fetch(uploadUrl, {
method: 'PUT',
body: file,
headers: {
'Content-Type': file.type || 'application/octet-stream',
},
});
const { responseHandler } = getConfig();
await responseHandler(response);
return url;
},

// eslint-disable-next-line @typescript-eslint/no-explicit-any
transformInput: async (input: Record<string, any>) => {
const promises = Object.entries(input).map(async ([key, value]) => {
if (
value instanceof Blob ||
(typeof value === 'string' && isDataUri(value))
) {
let blob = value;
// if string is a data uri, convert to blob
if (typeof value === 'string' && isDataUri(value)) {
const response = await fetch(value);
blob = await response.blob();
}
const url = await storageImpl.upload(blob as Blob);
return [key, url];
}
return [key, value] as KeyValuePair;
});
const results = await Promise.all(promises);
return Object.fromEntries(results);
},
};

0 comments on commit 78ffa58

Please sign in to comment.